Commit 11bc1775 authored by yuguo
Browse files
parents e12a1085 059d92e2
...@@ -756,7 +756,7 @@ for i in range(b): ...@@ -756,7 +756,7 @@ for i in range(b):
tensorwise_dequantize(dy_scales[i], x_scales[i], dw_int32, dw_ref[i]) tensorwise_dequantize(dy_scales[i], x_scales[i], dw_int32, dw_ref[i])
else: else:
assert False assert False
dw_ref_tensor = torch.stack(dw_ref).contiguous() dw_ref_tensor = torch.stack(dw_ref).contiguous().view(-1, dw_ref[0].size(-1))
# print("dw_ref_tensor: ", dw_ref_tensor) # print("dw_ref_tensor: ", dw_ref_tensor)
torch.cuda.synchronize() torch.cuda.synchronize()
...@@ -771,13 +771,13 @@ dw_tensor = torch.stack(dw).contiguous() ...@@ -771,13 +771,13 @@ dw_tensor = torch.stack(dw).contiguous()
out_dtype = torch.bfloat16 out_dtype = torch.bfloat16
dw_tensor = tex.tensorwise_int8_batchgemm( dw_tensor = tex.tensorwise_int8_batchgemm(
x_int8_tensor.view(-1, x_int8.size(-1)), x_int8_tensor.view(-1, x_int8_tensor.size(-1)),
transa, transa,
dy_int8_tensor.view(-1, dy_int8.size(-1)), dy_int8_tensor.view(-1, dy_int8_tensor.size(-1)),
transb, transb,
x_scales_tensor, x_scales_tensor,
dy_scales_tensor, dy_scales_tensor,
dw_tensor.view(-1, dw.size(-1)), dw_tensor.view(-1, dw_tensor.size(-1)),
b, b,
out_quantizer, out_quantizer,
TE_DType[out_dtype], TE_DType[out_dtype],
......
...@@ -499,11 +499,11 @@ def general_grouped_gemm( ...@@ -499,11 +499,11 @@ def general_grouped_gemm(
out[0], out[0],
num_gemms, num_gemms,
None, None,
TE_DType[out_dtype], out_dtype,
bias[0], None,
bias_dtype, bias_dtype,
gelu, gelu,
gelu_input[0], None,
grad, # grad grad, # grad
workspaces[0], workspaces[0],
workspaces[0].shape[0], workspaces[0].shape[0],
...@@ -534,11 +534,11 @@ def general_grouped_gemm( ...@@ -534,11 +534,11 @@ def general_grouped_gemm(
out[0], out[0],
num_gemms, num_gemms,
None, None,
TE_DType[out_dtype], out_dtype,
bias[0], None,
bias_dtype, bias_dtype,
gelu, gelu,
gelu_input[0], None,
grad, # grad grad, # grad
workspaces[0], workspaces[0],
workspaces[0].shape[0], workspaces[0].shape[0],
...@@ -571,10 +571,10 @@ def general_grouped_gemm( ...@@ -571,10 +571,10 @@ def general_grouped_gemm(
num_gemms, num_gemms,
None, None,
TE_DType[out_dtype], TE_DType[out_dtype],
bias[0], None,
bias_dtype, bias_dtype,
gelu, gelu,
gelu_input[0], None,
grad, # grad grad, # grad
workspaces[0], workspaces[0],
workspaces[0].shape[0], workspaces[0].shape[0],
...@@ -623,10 +623,10 @@ def general_grouped_gemm( ...@@ -623,10 +623,10 @@ def general_grouped_gemm(
num_gemms, num_gemms,
None, None,
TE_DType[torch.int32], TE_DType[torch.int32],
bias[0], None,
bias_dtype, bias_dtype,
gelu, gelu,
gelu_input[0], None,
grad, # grad grad, # grad
workspaces[0], workspaces[0],
workspaces[0].shape[0], workspaces[0].shape[0],
...@@ -671,10 +671,10 @@ def general_grouped_gemm( ...@@ -671,10 +671,10 @@ def general_grouped_gemm(
num_gemms, num_gemms,
None, None,
TE_DType[torch.int32], TE_DType[torch.int32],
bias[0], None,
bias_dtype, bias_dtype,
gelu, gelu,
gelu_input[0], None,
grad, # grad grad, # grad
workspaces[0], workspaces[0],
workspaces[0].shape[0], workspaces[0].shape[0],
...@@ -718,10 +718,10 @@ def general_grouped_gemm( ...@@ -718,10 +718,10 @@ def general_grouped_gemm(
num_gemms, num_gemms,
None, None,
TE_DType[torch.int32], TE_DType[torch.int32],
bias[0], None,
bias_dtype, bias_dtype,
gelu, gelu,
gelu_input[0], None,
grad, # grad grad, # grad
workspaces[0], workspaces[0],
workspaces[0].shape[0], workspaces[0].shape[0],
......
...@@ -937,7 +937,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -937,7 +937,7 @@ class _LayerNormLinear(torch.autograd.Function):
) )
dgrad = dgrad.reshape(inputmat.size()) dgrad = dgrad.reshape(inputmat.size())
elif ctx.normalization == "RMSNorm": elif ctx.normalization == "RMSNorm":
if enable_lightop: if enable_lightop and (rsigma is torch.bfloat16 or rsigma is torch.float16):
dgrad, dgamma =rmsnorm_backward(dgrad,inputmat,rsigma,ln_weight) dgrad, dgamma =rmsnorm_backward(dgrad,inputmat,rsigma,ln_weight)
else: else:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment