Commit 00fcd784 authored by yuguo

[DCU] fix

parent 4922108e
...
@@ -488,31 +488,6 @@ def general_grouped_gemm(
     else:
         raise ValueError(f"Unsupported layout {layout} in int8 simulation fp8")
-    if int8_simulation_fp8 and (isinstance(A[0], Float8TensorBase) or isinstance(B[0], Float8TensorBase)) and int8_simulation_fp8_tensorwise:
-        assert not gelu, "GELU not supported with int8 simulation groupgemm."
-        assert TE_DType_To_Torch[out_dtype] is torch.bfloat16 or TE_DType_To_Torch[out_dtype] is torch.float32, "Out_dtype must be bfloat16 or float32 for int8 simulation"
-        bias = tex.te_general_grouped_gemm(
-            A,
-            transa,
-            B,
-            transb,
-            out,
-            out_dtype,
-            m_splits,
-            grad_bias if grad else bias,
-            bias_dtype,
-            single_output,
-            gelu_input,  # this is pre_gelu_out
-            grad,  # grad
-            workspaces,
-            workspaces[0].shape[0],
-            accumulate,
-            use_split_accumulator,
-            sm_count - int(os.getenv("NVTE_EXT_MARGIN_SM", str(sm_count))),
-        )
-        return out, bias, gelu_input
     if int8_simulation_fp8 and (isinstance(A[0], Float8TensorBase) or isinstance(B[0], Float8TensorBase)) and int8_simulation_fp8_tensorwise and int8_simulation_fp8_tensorwise_batched:
         assert len(set(m_splits)) == 1, "Token padding must match batchgemm for NVTE_INT8_SIM_FP8_TENSORWISE_BATCHED."
...
@@ -628,6 +603,32 @@ def general_grouped_gemm(
         return out, bias, gelu_input
+    if int8_simulation_fp8 and (isinstance(A[0], Float8TensorBase) or isinstance(B[0], Float8TensorBase)) and int8_simulation_fp8_tensorwise:
+        assert not gelu, "GELU not supported with int8 simulation groupgemm."
+        assert TE_DType_To_Torch[out_dtype] is torch.bfloat16 or TE_DType_To_Torch[out_dtype] is torch.float32, "Out_dtype must be bfloat16 or float32 for int8 simulation"
+        bias = tex.te_general_grouped_gemm(
+            A,
+            transa,
+            B,
+            transb,
+            out,
+            out_dtype,
+            m_splits,
+            grad_bias if grad else bias,
+            bias_dtype,
+            single_output,
+            gelu_input,  # this is pre_gelu_out
+            grad,  # grad
+            workspaces,
+            workspaces[0].shape[0],
+            accumulate,
+            use_split_accumulator,
+            sm_count - int(os.getenv("NVTE_EXT_MARGIN_SM", str(sm_count))),
+        )
+        return out, bias, gelu_input
     if int8_simulation_fp8 and (isinstance(A[0], Float8TensorBase) or isinstance(B[0], Float8TensorBase)):
         assert len(set(m_splits)) == 1, "Int8 simulation groupgemm only supports token padding matching batchgemm for now."
         assert not gelu, "GELU not supported with int8 simulation groupgemm."
...
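The moved block itself is unchanged; only its position in the if-chain changes. Before this commit, the plain `int8_simulation_fp8_tensorwise` branch returned before the `int8_simulation_fp8_tensorwise_batched` branch could ever be reached, because the batched condition also requires the tensorwise flag. A minimal sketch of the dispatch order after the move (the helper function and return strings are hypothetical stand-ins, not TE code; only the flag names come from the diff):

```python
# Hypothetical illustration of the guard ordering fixed by this commit.
def pick_grouped_gemm_path(int8_simulation_fp8: bool,
                           tensorwise: bool,
                           batched: bool) -> str:
    # Most specific guard first: tensorwise AND batched.
    if int8_simulation_fp8 and tensorwise and batched:
        return "tensorwise-batched int8-sim path"
    # The plain tensorwise branch now comes second, so it no longer
    # shadows (and dead-codes) the batched branch above.
    if int8_simulation_fp8 and tensorwise:
        return "tensorwise int8-sim path"
    if int8_simulation_fp8:
        return "generic int8-sim path"
    return "native fp8 grouped GEMM"

assert pick_grouped_gemm_path(True, True, True) == "tensorwise-batched int8-sim path"
assert pick_grouped_gemm_path(True, True, False) == "tensorwise int8-sim path"
```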
...
@@ -949,7 +949,7 @@ class _LayerNormLinear(torch.autograd.Function):
             )
             dgrad = dgrad.reshape(inputmat.size())
         elif ctx.normalization == "RMSNorm":
-            if enable_lightop and (rsigma is torch.bfloat16 or rsigma is torch.float16):
+            if enable_lightop and (rsigma.dtype is torch.bfloat16 or rsigma.dtype is torch.float16):
                 dgrad, dgamma = rmsnorm_backward(dgrad, inputmat, rsigma, ln_weight)
             else:
...
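The `_LayerNormLinear` hunk fixes a real bug: `rsigma` is a tensor, so the old identity test `rsigma is torch.bfloat16` compared a `Tensor` against a `torch.dtype` object and was always `False`, meaning the lightop `rmsnorm_backward` fast path was never taken. Checking `rsigma.dtype` restores it. A small self-contained check of that reasoning (plain PyTorch, no TE required):

```python
import torch

rsigma = torch.ones(8, dtype=torch.bfloat16)

# Old guard: a Tensor is never identical to a dtype object, so the
# condition was always False and the lightop branch was dead code.
assert (rsigma is torch.bfloat16) is False

# Fixed guard: torch dtypes are singletons, so an identity check on
# `.dtype` works and selects the fast path for bf16/fp16 rsigma.
assert rsigma.dtype is torch.bfloat16
use_lightop = rsigma.dtype is torch.bfloat16 or rsigma.dtype is torch.float16
assert use_lightop
```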