Commit 00fcd784 authored by yuguo
Browse files

[DCU] fix

parent 4922108e
......@@ -488,31 +488,6 @@ def general_grouped_gemm(
else:
raise ValueError(f"Unsupported layout {layout} in int8 simulation fp8")
if int8_simulation_fp8 and (isinstance(A[0], Float8TensorBase) or isinstance(B[0], Float8TensorBase)) and int8_simulation_fp8_tensorwise:
assert not gelu, "GELU not supported with int8 simulation groupgemm."
assert TE_DType_To_Torch[out_dtype] is torch.bfloat16 or TE_DType_To_Torch[out_dtype] is torch.float32, "Out_dtype must be bfloat16 or float32 for int8 simulation"
bias = tex.te_general_grouped_gemm(
A,
transa,
B,
transb,
out,
out_dtype,
m_splits,
grad_bias if grad else bias,
bias_dtype,
single_output,
gelu_input, # this is pre_gelu_out
grad, # grad
workspaces,
workspaces[0].shape[0],
accumulate,
use_split_accumulator,
sm_count - int(os.getenv("NVTE_EXT_MARGIN_SM", str(sm_count))),
)
return out, bias, gelu_input
if int8_simulation_fp8 and (isinstance(A[0], Float8TensorBase) or isinstance(B[0], Float8TensorBase)) and int8_simulation_fp8_tensorwise and int8_simulation_fp8_tensorwise_batched:
assert len(set(m_splits)) == 1, "Need token pad as same as batchgemm for NVTE_INT8_SIM_FP8_TENSORWISE_BATCHED."
......@@ -626,6 +601,32 @@ def general_grouped_gemm(
for i in range(num_gemms):
out[i].copy_(dw[i])
return out, bias, gelu_input
if int8_simulation_fp8 and (isinstance(A[0], Float8TensorBase) or isinstance(B[0], Float8TensorBase)) and int8_simulation_fp8_tensorwise:
assert not gelu, "GELU not supported with int8 simulation groupgemm."
assert TE_DType_To_Torch[out_dtype] is torch.bfloat16 or TE_DType_To_Torch[out_dtype] is torch.float32, "Out_dtype must be bfloat16 or float32 for int8 simulation"
bias = tex.te_general_grouped_gemm(
A,
transa,
B,
transb,
out,
out_dtype,
m_splits,
grad_bias if grad else bias,
bias_dtype,
single_output,
gelu_input, # this is pre_gelu_out
grad, # grad
workspaces,
workspaces[0].shape[0],
accumulate,
use_split_accumulator,
sm_count - int(os.getenv("NVTE_EXT_MARGIN_SM", str(sm_count))),
)
return out, bias, gelu_input
if int8_simulation_fp8 and (isinstance(A[0], Float8TensorBase) or isinstance(B[0], Float8TensorBase)):
......
......@@ -949,7 +949,7 @@ class _LayerNormLinear(torch.autograd.Function):
)
dgrad = dgrad.reshape(inputmat.size())
elif ctx.normalization == "RMSNorm":
if enable_lightop and (rsigma is torch.bfloat16 or rsigma is torch.float16):
if enable_lightop and (rsigma.dtype is torch.bfloat16 or rsigma.dtype is torch.float16):
dgrad, dgamma =rmsnorm_backward(dgrad,inputmat,rsigma,ln_weight)
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment