Commit 98de2cdd authored by yuguo

Merge branch 'develop_v2.4' into 'main'

[DCU] Fix blockwise int8 training issues in Megatron

See merge request dcutoolkit/deeplearing/TransformerEngine!30
parents 5b82e699 ecdd8251
@@ -82,8 +82,6 @@ def general_gemm(
         if accumulate:
             assert out is not None
             y = y + out
-        else:
-            assert out is None, "Output tensor should be None when accumulate is False."
         return y, None, None, None
     elif layout == "NN":
@@ -103,8 +101,6 @@ def general_gemm(
         if accumulate:
             assert out is not None
             y = y + out
-        else:
-            assert out is None, "Output tensor should be None when accumulate is False."
         return y, None, None, None
     elif layout == "NT":
@@ -124,8 +120,6 @@ def general_gemm(
         if accumulate:
             assert out is not None
             y = y + out
-        else:
-            assert out is None, "Output tensor should be None when accumulate is False."
         return y, None, None, None
     else:
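The three hunks above all drop the same guard: the old code asserted that out is None whenever accumulate is False, which rejects callers that hand in a pre-allocated output buffer without wanting accumulation. A minimal standalone sketch of the epilogue logic after the change (hypothetical helper, not the real general_gemm signature):

import torch

def gemm_epilogue(y, out=None, accumulate=False):
    # Keep the accumulate path strict: an accumulator must be provided.
    if accumulate:
        assert out is not None
        y = y + out
    # The removed branch asserted `out is None` here, which broke callers
    # (e.g. Megatron-style layers) that pass a destination buffer with
    # accumulate=False; after the change a non-None `out` is simply ignored.
    return y, None, None, None

y, *_ = gemm_epilogue(torch.randn(4, 8), out=torch.zeros(4, 8), accumulate=True)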
@@ -234,9 +228,8 @@ def general_grouped_gemm(
         )
         if accumulate:
             assert out is not None
+            out = torch.stack(out).contiguous()
             y = y + out
-        else:
-            assert out is None, "Output tensor should be None when accumulate is False."
         return y, None, None
     elif layout == "NN":
@@ -255,9 +248,8 @@ def general_grouped_gemm(
         )
         if accumulate:
             assert out is not None
+            out = torch.stack(out).contiguous()
             y = y + out
-        else:
-            assert out is None, "Output tensor should be None when accumulate is False."
         return y, None, None
     elif layout == "NT":
@@ -276,9 +268,8 @@ def general_grouped_gemm(
         )
         if accumulate:
             assert out is not None
+            out = torch.stack(out).contiguous()
             y = y + out
-        else:
-            assert out is None, "Output tensor should be None when accumulate is False."
         return y, None, None
     else:
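In the grouped path the accumulate branch additionally stacks out before the add, which suggests out arrives as a list of per-group tensors while y is already a single stacked tensor. A hedged sketch of that shape handling (dimensions are made up for illustration):

import torch

num_groups, m, n = 4, 16, 32
y = torch.randn(num_groups, m, n)                      # grouped GEMM result, already stacked
out = [torch.randn(m, n) for _ in range(num_groups)]   # per-group accumulation buffers

out = torch.stack(out).contiguous()                    # (num_groups, m, n)
y = y + out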
@@ -324,6 +324,8 @@ def w8a8_block_int8_matmul_wgrad(
     """
     assert len(block_size) == 2
     block_n, block_k = block_size[0], block_size[1]
+    B = B.view(B.size(0), -1)
+    assert A.ndim == 2
     assert A.shape[-1] == B.shape[-1]
     # print(f"A.shape[:-1] : {A.shape[:-1]}, As.shape[:-1]: {As.shape[:-1]}")
@@ -451,6 +453,9 @@ def w8a8_block_int8_matmul_wgrad_batched(
     As = torch.stack(As_list).contiguous()
     Bs = torch.stack(Bs_list).contiguous()
+    B_new_shape = B.size()[:2] + (-1,)
+    B = B.view(*B_new_shape)
+    assert A.ndim == 3
     assert A.shape[-1] == B.shape[-1]
     M = A.numel() // A.shape[-1] // A.shape[0]
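The batched variant gets the analogous fix: the first two dimensions are kept and the remaining trailing dimensions are collapsed, so the reshaped B lines up with a 3-D A. Sketch with illustrative shapes:

import torch

A = torch.randint(-128, 127, (2, 64, 1024), dtype=torch.int8)
B = torch.randint(-128, 127, (2, 512, 8, 128), dtype=torch.int8)

B_new_shape = B.size()[:2] + (-1,)   # torch.Size([2, 512]) + (-1,)
B = B.view(*B_new_shape)             # -> (2, 512, 1024)
assert A.ndim == 3
assert A.shape[-1] == B.shape[-1]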