Commit 11bc1775 authored by yuguo
parents e12a1085 059d92e2
@@ -756,7 +756,7 @@ for i in range(b):
tensorwise_dequantize(dy_scales[i], x_scales[i], dw_int32, dw_ref[i])
else:
assert False
dw_ref_tensor = torch.stack(dw_ref).contiguous()
dw_ref_tensor = torch.stack(dw_ref).contiguous().view(-1, dw_ref[0].size(-1))
# print("dw_ref_tensor: ", dw_ref_tensor)
torch.cuda.synchronize()
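A minimal sketch of the reference path in this hunk, under assumed shapes and a hypothetical helper: each per-matrix dw is accumulated in int32 and dequantized with the per-tensor scales, then the stacked reference is flattened to two dimensions so it lines up with the batched kernel's output layout. b, out_features, in_features, and tensorwise_dequantize_ref below are illustrative, not from the source.

import torch

def tensorwise_dequantize_ref(dy_scale, x_scale, dw_int32, out):
    # Hypothetical reference: scale the int32 accumulator back to real values.
    out.copy_(dw_int32.to(torch.float32) * (dy_scale * x_scale))

b, out_features, in_features = 4, 8, 16  # assumed shapes
dw_ref = [torch.empty(out_features, in_features) for _ in range(b)]
for i in range(b):
    dw_int32 = torch.randint(-1000, 1000, (out_features, in_features), dtype=torch.int32)
    tensorwise_dequantize_ref(torch.tensor(0.5), torch.tensor(0.25), dw_int32, dw_ref[i])

# Stacking gives (b, out_features, in_features); viewing with the last dim of
# dw_ref[0] flattens it to (b * out_features, in_features), matching the 2-D
# result produced by the batched GEMM.
dw_ref_tensor = torch.stack(dw_ref).contiguous().view(-1, dw_ref[0].size(-1))
assert dw_ref_tensor.shape == (b * out_features, in_features)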
@@ -771,13 +771,13 @@ dw_tensor = torch.stack(dw).contiguous()
out_dtype = torch.bfloat16
dw_tensor = tex.tensorwise_int8_batchgemm(
x_int8_tensor.view(-1, x_int8.size(-1)),
x_int8_tensor.view(-1, x_int8_tensor.size(-1)),
transa,
dy_int8_tensor.view(-1, dy_int8.size(-1)),
dy_int8_tensor.view(-1, dy_int8_tensor.size(-1)),
transb,
x_scales_tensor,
dy_scales_tensor,
dw_tensor.view(-1, dw.size(-1)),
dw_tensor.view(-1, dw_tensor.size(-1)),
b,
out_quantizer,
TE_DType[out_dtype],
......
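The three view fixes in this hunk follow one pattern: the flatten width is now taken from the stacked tensor itself (x_int8_tensor, dy_int8_tensor, dw_tensor) rather than from a per-sample tensor, so the reshape cannot silently use a mismatched last dimension. A small illustration with assumed shapes:

import torch

# Assumed batch of int8 activations: (b, seq, hidden).
x_int8_tensor = torch.randint(-128, 128, (4, 32, 64), dtype=torch.int8)

# Flatten width derived from the stacked tensor itself, always consistent
# with the data actually being reshaped.
flat = x_int8_tensor.view(-1, x_int8_tensor.size(-1))
assert flat.shape == (4 * 32, 64)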
@@ -499,11 +499,11 @@ def general_grouped_gemm(
out[0],
num_gemms,
None,
TE_DType[out_dtype],
bias[0],
out_dtype,
None,
bias_dtype,
gelu,
gelu_input[0],
None,
grad, # grad
workspaces[0],
workspaces[0].shape[0],
@@ -534,11 +534,11 @@ def general_grouped_gemm(
out[0],
num_gemms,
None,
TE_DType[out_dtype],
bias[0],
out_dtype,
None,
bias_dtype,
gelu,
gelu_input[0],
None,
grad, # grad
workspaces[0],
workspaces[0].shape[0],
@@ -571,10 +571,10 @@ def general_grouped_gemm(
num_gemms,
None,
TE_DType[out_dtype],
bias[0],
None,
bias_dtype,
gelu,
gelu_input[0],
None,
grad, # grad
workspaces[0],
workspaces[0].shape[0],
@@ -623,10 +623,10 @@ def general_grouped_gemm(
num_gemms,
None,
TE_DType[torch.int32],
bias[0],
None,
bias_dtype,
gelu,
gelu_input[0],
None,
grad, # grad
workspaces[0],
workspaces[0].shape[0],
@@ -671,10 +671,10 @@ def general_grouped_gemm(
num_gemms,
None,
TE_DType[torch.int32],
bias[0],
None,
bias_dtype,
gelu,
gelu_input[0],
None,
grad, # grad
workspaces[0],
workspaces[0].shape[0],
@@ -718,10 +718,10 @@ def general_grouped_gemm(
num_gemms,
None,
TE_DType[torch.int32],
bias[0],
None,
bias_dtype,
gelu,
gelu_input[0],
None,
grad, # grad
workspaces[0],
workspaces[0].shape[0],
......
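For reference, TE_DType in the call sites above is Transformer Engine's mapping from torch dtypes to its own enum values; passing TE_DType[torch.int32] requests an int32 output/accumulation type for the int8 grouped GEMM. A hedged usage sketch, assuming the upstream constants module is importable at its usual path:

import torch
# Assumption: standard upstream location of the dtype mapping.
from transformer_engine.pytorch.constants import TE_DType

# torch dtype -> TE enum lookups, as used in general_grouped_gemm above.
out_te_dtype = TE_DType[torch.bfloat16]
acc_te_dtype = TE_DType[torch.int32]
print(out_te_dtype, acc_te_dtype)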
@@ -937,7 +937,7 @@ class _LayerNormLinear(torch.autograd.Function):
)
dgrad = dgrad.reshape(inputmat.size())
elif ctx.normalization == "RMSNorm":
if enable_lightop:
if enable_lightop and (rsigma.dtype is torch.bfloat16 or rsigma.dtype is torch.float16):
dgrad, dgamma = rmsnorm_backward(dgrad, inputmat, rsigma, ln_weight)
else:
......
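The RMSNorm change above gates the fused lightop backward on both the feature flag and a half-precision dtype check, with the non-lightop branch (elided by the diff) as the fallback. A self-contained sketch of the intended control flow; the stub, the tensor shapes, and the assumption that rsigma carries the activation dtype are illustrative only:

import torch

def rmsnorm_backward_stub(dgrad, inputmat, rsigma, ln_weight):
    # Placeholder standing in for the fused lightop kernel (hypothetical stub).
    return dgrad, torch.zeros_like(ln_weight)

enable_lightop = True
inputmat = torch.randn(8, 16, dtype=torch.bfloat16)
dgrad = torch.randn(8, 16, dtype=torch.bfloat16)
rsigma = torch.randn(8, dtype=torch.bfloat16)
ln_weight = torch.randn(16, dtype=torch.bfloat16)

# Gate on the feature flag plus a half-precision dtype check, as in the hunk.
if enable_lightop and rsigma.dtype in (torch.bfloat16, torch.float16):
    dgrad, dgamma = rmsnorm_backward_stub(dgrad, inputmat, rsigma, ln_weight)
else:
    ...  # fall back to the generic RMSNorm backward path (elided in the diff)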