Commit e56de127 authored by yuguo

Merge branch 'develop_v2.4' into 'main'

[DCU] Fix Megatron MoE int8 training bugs

See merge request dcutoolkit/deeplearing/TransformerEngine!37
parents 68487b2a 251dcc7e
@@ -11,7 +11,7 @@ import transformer_engine_torch as tex
 from ..constants import TE_DType
 from ..utils import get_sm_count, _empty_tensor
 from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt import w8a8_block_int8_matmul, w8a8_block_int8_matmul_batched
-from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt_wgrad import w8a8_block_int8_matmul_wgrad, w8a8_block_int8_matmul_wgrad_batched
+from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt_wgrad import w8a8_block_int8_matmul_wgrad, w8a8_block_int8_matmul_wgrad_batched, w8a8_block_int8_matmul_wgrad_batched_native
 from ..tensor.quantized_tensor import Quantizer
 from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
 from ...debug.pytorch.debug_quantization import DebugQuantizer
@@ -285,7 +285,7 @@ def general_grouped_gemm(
         ref_scales_dout = [b._columnwise_scale_inv for b in B]
         ref_scales_x = [a._columnwise_scale_inv for a in A]
-        out = w8a8_block_int8_matmul_wgrad_batched(
+        out = w8a8_block_int8_matmul_wgrad_batched_native(
             qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate, [128, 128],
             output_dtype=out_dtype
         )
...
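For context, the `[128, 128]` argument above is the block-wise quantization granularity: each 128x128 tile of an operand carries its own int8 scale, and the per-tile scale tensors (the `_columnwise_scale_inv` lists above) are passed alongside the int8 data. A minimal PyTorch sketch of how such tiles and scales might be produced; `quantize_blockwise_int8` is not a TransformerEngine function, and the exact scale convention used by the kernels is an assumption here:

import torch

def quantize_blockwise_int8(x: torch.Tensor, block: int = 128):
    # One int8 scale per (block x block) tile, as in the [128, 128] case above.
    M, K = x.shape
    pad_m, pad_k = (-M) % block, (-K) % block
    xp = torch.nn.functional.pad(x, (0, pad_k, 0, pad_m))
    tiles = xp.view(xp.shape[0] // block, block, xp.shape[1] // block, block)
    scale = tiles.abs().amax(dim=(1, 3)).clamp(min=1e-8) / 127.0  # per-tile dequant step
    q = (tiles / scale[:, None, :, None]).round().clamp(-127, 127).to(torch.int8)
    return q.view(xp.shape)[:M, :K], scale  # q * scale (per tile) recovers x approximately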
@@ -462,14 +462,58 @@ def w8a8_block_int8_matmul_batched(
     assert C.size(-1) == N

-    config = {
-        "BLOCK_SIZE_M": 64,
-        "BLOCK_SIZE_N": block_n,
-        "BLOCK_SIZE_K": block_k,
-        "GROUP_SIZE_M": 8,
-        "num_warps": 4,
-        "num_stages": 1,
-    }
+    if best_config:
+        config = best_config
+    else:
+        # Default config, selected by the number of output rows M.
+        # Block-wise quant: BLOCK_SIZE_K must be divisible by block_size[1].
+        if M <= 64:
+            config = {
+                "BLOCK_SIZE_M": 16,
+                "BLOCK_SIZE_N": block_size[0],
+                "BLOCK_SIZE_K": block_size[1],
+                "GROUP_SIZE_M": 2,
+                "num_warps": 4,
+                "num_stages": 0,
+            }
+        elif M < 128:
+            config = {
+                "BLOCK_SIZE_M": 32,
+                "BLOCK_SIZE_N": block_size[0],
+                "BLOCK_SIZE_K": block_size[1],
+                "GROUP_SIZE_M": 2,
+                "num_warps": 4,
+                "num_stages": 0,
+            }
+        elif M <= 256:
+            config = {
+                "BLOCK_SIZE_M": 64,
+                "BLOCK_SIZE_N": block_size[0],
+                "BLOCK_SIZE_K": block_size[1],
+                "GROUP_SIZE_M": 2,
+                "num_warps": 4,
+                "num_stages": 0,
+            }
+        else:
+            config = {
+                "BLOCK_SIZE_M": 64,
+                "BLOCK_SIZE_N": block_size[0],
+                "BLOCK_SIZE_K": block_size[1],
+                "GROUP_SIZE_M": 8,
+                "num_warps": 8,
+                "num_stages": 0,
+            }

     def grid(META):
         return (
...
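The same four-way M heuristic is duplicated in `w8a8_block_int8_matmul_wgrad_batched` below, so it could be factored into a shared helper. A sketch of such a refactor (hypothetical, not part of this patch) that mirrors the branches above exactly:

def select_blockwise_int8_config(M: int, block_size):
    # Small M: narrow 16/32-row tiles, GROUP_SIZE_M=2, 4 warps.
    # Large M: 64-row tiles, GROUP_SIZE_M=8, 8 warps.
    if M <= 64:
        bm, gm, warps = 16, 2, 4
    elif M < 128:
        bm, gm, warps = 32, 2, 4
    elif M <= 256:
        bm, gm, warps = 64, 2, 4
    else:
        bm, gm, warps = 64, 8, 8
    return {
        "BLOCK_SIZE_M": bm,
        "BLOCK_SIZE_N": block_size[0],
        "BLOCK_SIZE_K": block_size[1],
        "GROUP_SIZE_M": gm,
        "num_warps": warps,
        "num_stages": 0,
    }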
@@ -456,6 +456,18 @@ def w8a8_block_int8_matmul_wgrad(
     return C, config

+def w8a8_block_int8_matmul_wgrad_batched_native(
+    A_list, B_list, As_list, Bs_list, C_list, accumulate,
+    block_size, output_dtype=torch.float16, best_config=None
+):
+    # Launch the single-matrix wgrad kernel once per pair instead of one
+    # batched kernel over stacked inputs, writing into the provided buffers.
+    for i in range(len(C_list)):
+        assert C_list[i] is not None
+        C_list[i], config = w8a8_block_int8_matmul_wgrad(
+            A_list[i], B_list[i], As_list[i], Bs_list[i], C_list[i], accumulate, block_size,
+            output_dtype=output_dtype,
+            best_config=best_config,
+        )
+    return C_list
+

 def w8a8_block_int8_matmul_wgrad_batched(
     A_list, B_list, As_list, Bs_list, C_list, accumulate,
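The new `_native` variant avoids `torch.stack` over the per-expert tensors and instead reuses the single-GEMM wgrad kernel in a loop. A self-contained sketch of the same dispatch pattern, with a dense `single_gemm` standing in for the Triton kernel (shapes and the NT layout here are assumptions for illustration):

import torch

def batched_native(A_list, B_list, C_list, single_gemm):
    # One kernel launch per group member, writing into caller-provided buffers.
    for i in range(len(C_list)):
        assert C_list[i] is not None
        C_list[i] = single_gemm(A_list[i], B_list[i], C_list[i])
    return C_list

def single_gemm(a, b, c):
    # Stand-in for the Triton wgrad kernel: C = A @ B^T, accumulated in fp32.
    c.copy_((a.float() @ b.float().t()).to(c.dtype))
    return c

A = [torch.randn(64, 128) for _ in range(4)]   # e.g. per-expert grad-output
B = [torch.randn(32, 128) for _ in range(4)]   # e.g. per-expert input
C = [torch.empty(64, 32) for _ in range(4)]
out = batched_native(A, B, C, single_gemm)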
@@ -487,14 +499,58 @@ def w8a8_block_int8_matmul_wgrad_batched(
     else:
         C = torch.stack(C_list).contiguous()

-    config = {
-        "BLOCK_SIZE_M": 64,
-        "BLOCK_SIZE_N": block_n,
-        "BLOCK_SIZE_K": block_k,
-        "GROUP_SIZE_M": 8,
-        "num_warps": 4,
-        "num_stages": 1,
-    }
+    if best_config:
+        config = best_config
+    else:
+        # Default config, selected by the number of output rows M.
+        # Block-wise quant: BLOCK_SIZE_K must be divisible by block_size[1].
+        if M <= 64:
+            config = {
+                "BLOCK_SIZE_M": 16,
+                "BLOCK_SIZE_N": block_size[0],
+                "BLOCK_SIZE_K": block_size[1],
+                "GROUP_SIZE_M": 2,
+                "num_warps": 4,
+                "num_stages": 0,
+            }
+        elif M < 128:
+            config = {
+                "BLOCK_SIZE_M": 32,
+                "BLOCK_SIZE_N": block_size[0],
+                "BLOCK_SIZE_K": block_size[1],
+                "GROUP_SIZE_M": 2,
+                "num_warps": 4,
+                "num_stages": 0,
+            }
+        elif M <= 256:
+            config = {
+                "BLOCK_SIZE_M": 64,
+                "BLOCK_SIZE_N": block_size[0],
+                "BLOCK_SIZE_K": block_size[1],
+                "GROUP_SIZE_M": 2,
+                "num_warps": 4,
+                "num_stages": 0,
+            }
+        else:
+            config = {
+                "BLOCK_SIZE_M": 64,
+                "BLOCK_SIZE_N": block_size[0],
+                "BLOCK_SIZE_K": block_size[1],
+                "GROUP_SIZE_M": 8,
+                "num_warps": 8,
+                "num_stages": 0,
+            }

     def grid(META):
         return (
...