Commit ecdd8251 authored by yuguo
Browse files

[DCU] fix blockwise int8 train issues in megatron

parent 7f946529
......@@ -82,8 +82,6 @@ def general_gemm(
if accumulate:
assert out is not None
y = y + out
else:
assert out is None, "Output tensor should be None when accumulate is False."
return y, None, None, None
elif layout == "NN":
......@@ -103,8 +101,6 @@ def general_gemm(
if accumulate:
assert out is not None
y = y + out
else:
assert out is None, "Output tensor should be None when accumulate is False."
return y, None, None, None
elif layout == "NT":
......@@ -124,8 +120,6 @@ def general_gemm(
if accumulate:
assert out is not None
y = y + out
else:
assert out is None, "Output tensor should be None when accumulate is False."
return y, None, None, None
else:
......@@ -234,9 +228,8 @@ def general_grouped_gemm(
)
if accumulate:
assert out is not None
out = torch.stack(out).contiguous()
y = y + out
else:
assert out is None, "Output tensor should be None when accumulate is False."
return y, None, None
elif layout == "NN":
......@@ -255,9 +248,8 @@ def general_grouped_gemm(
)
if accumulate:
assert out is not None
out = torch.stack(out).contiguous()
y = y + out
else:
assert out is None, "Output tensor should be None when accumulate is False."
return y, None, None
elif layout == "NT":
......@@ -276,9 +268,8 @@ def general_grouped_gemm(
)
if accumulate:
assert out is not None
out = torch.stack(out).contiguous()
y = y + out
else:
assert out is None, "Output tensor should be None when accumulate is False."
return y, None, None
else:
......
......@@ -324,6 +324,8 @@ def w8a8_block_int8_matmul_wgrad(
"""
assert len(block_size) == 2
block_n, block_k = block_size[0], block_size[1]
B = B.view(B.size(0), -1)
assert A.ndim == 2
assert A.shape[-1] == B.shape[-1]
# print(f"A.shape[:-1] : {A.shape[:-1]}, As.shape[:-1]: {As.shape[:-1]}")
......@@ -451,6 +453,9 @@ def w8a8_block_int8_matmul_wgrad_batched(
As = torch.stack(As_list).contiguous()
Bs = torch.stack(Bs_list).contiguous()
B_new_shape = B.size()[:2] + (-1,)
B = B.view(*B_new_shape)
assert A.ndim == 3
assert A.shape[-1] == B.shape[-1]
M = A.numel() // A.shape[-1] // A.shape[0]
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment