Unverified Commit a435ec01 authored by Xin Yao's avatar Xin Yao Committed by GitHub
Browse files

[PyTorch] Remove implicit padding and unpadding in `GroupedLinear` (#984)



* remove implicit padding and unpadding
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>
---------
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarPhuong Nguyen <36155692+phu0ngng@users.noreply.github.com>
parent e3bb24e5
...@@ -71,14 +71,6 @@ _GEMM_OUTPUT = 0 ...@@ -71,14 +71,6 @@ _GEMM_OUTPUT = 0
_GRAD_OUTPUT = 0 _GRAD_OUTPUT = 0
def _pad_tensor(inp: torch.Tensor):
if inp.shape[0] % 16 == 0:
return inp
pad_len = (inp.shape[0] + 15) // 16 * 16 - inp.shape[0]
pad_tensor = torch.zeros(pad_len, inp.shape[1], dtype=inp.dtype, device=inp.device)
return torch.cat((inp, pad_tensor), dim=0)
class _GroupedLinear(torch.autograd.Function): class _GroupedLinear(torch.autograd.Function):
"""GroupedLinear semi-top level module """GroupedLinear semi-top level module
Calls custom cuda extensions. Calls custom cuda extensions.
...@@ -116,7 +108,6 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -116,7 +108,6 @@ class _GroupedLinear(torch.autograd.Function):
assert inp.shape[-1] == in_features, "GEMM not possible" assert inp.shape[-1] == in_features, "GEMM not possible"
inputmats = torch.split(inp.view(-1, in_features), m_splits) inputmats = torch.split(inp.view(-1, in_features), m_splits)
if fp8: if fp8:
inputmats = [_pad_tensor(mat) for mat in inputmats]
for i in range(num_gemms): for i in range(num_gemms):
assert_dim_for_fp8_exec(inputmats[i]) assert_dim_for_fp8_exec(inputmats[i])
assert_dim_for_fp8_exec(weights[i]) assert_dim_for_fp8_exec(weights[i])
...@@ -170,14 +161,11 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -170,14 +161,11 @@ class _GroupedLinear(torch.autograd.Function):
weights_fp8 = weights weights_fp8 = weights
assert all(isinstance(w, Float8Tensor) for w in weights_fp8) assert all(isinstance(w, Float8Tensor) for w in weights_fp8)
out_list = [ out = torch.empty(
torch.empty( [sum(m_splits), weights_fp8[0].size(0)],
[inputmats[i].size(0), weights_fp8[0].size(0)], dtype=activation_dtype,
dtype=activation_dtype, device=inputmats[0].device,
device=inputmats[i].device, )
)
for i in range(num_gemms)
]
_ = fp8_grouped_gemm( _ = fp8_grouped_gemm(
[w._data for w in weights_fp8], [w._data for w in weights_fp8],
...@@ -188,15 +176,13 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -188,15 +176,13 @@ class _GroupedLinear(torch.autograd.Function):
fp8_meta["scaling_fwd"].scale_inv, fp8_meta["scaling_fwd"].scale_inv,
_GEMM_INPUT, _GEMM_INPUT,
fp8_dtype_forward, fp8_dtype_forward,
out_list, torch.split(out, m_splits),
activation_dtype, activation_dtype,
get_multi_stream_cublas_workspace(), get_multi_stream_cublas_workspace(),
bias=biases, bias=biases,
use_bias=use_bias, use_bias=use_bias,
use_split_accumulator=_2X_ACC_FPROP, use_split_accumulator=_2X_ACC_FPROP,
) )
# unpad the output
out = torch.cat([o[: m_splits[i]] for i, o in enumerate(out_list)], dim=0)
else: else:
logger.debug("Running forward in %s", activation_dtype) logger.debug("Running forward in %s", activation_dtype)
...@@ -333,7 +319,6 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -333,7 +319,6 @@ class _GroupedLinear(torch.autograd.Function):
if ctx.fp8: if ctx.fp8:
fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)
grad_output_mats = [_pad_tensor(mat) for mat in grad_output_mats]
if ctx.use_bias: if ctx.use_bias:
for i in range(ctx.num_gemms): for i in range(ctx.num_gemms):
grad_biases[i], grad_output_c[i], grad_output_t[i] = ( grad_biases[i], grad_output_c[i], grad_output_t[i] = (
...@@ -372,14 +357,11 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -372,14 +357,11 @@ class _GroupedLinear(torch.autograd.Function):
if ctx.requires_dgrad: if ctx.requires_dgrad:
if ctx.fp8: if ctx.fp8:
logger.debug("Running backward in FP8") logger.debug("Running backward in FP8")
dgrad_list = [ dgrad = torch.empty(
torch.empty( (sum(ctx.m_splits), weights_fp8[i].size(1)),
(grad_output_c[i].size(0), weights_fp8[i].size(1)), dtype=ctx.activation_dtype,
dtype=ctx.activation_dtype, device=grad_output.device,
device=grad_output.device, )
)
for i in range(ctx.num_gemms)
]
fp8_grouped_gemm( fp8_grouped_gemm(
[w.transpose_2d() for w in weights_fp8], [w.transpose_2d() for w in weights_fp8],
torch.cat( torch.cat(
...@@ -389,17 +371,13 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -389,17 +371,13 @@ class _GroupedLinear(torch.autograd.Function):
weights_fp8[0]._fp8_dtype, weights_fp8[0]._fp8_dtype,
grad_output_c, grad_output_c,
ctx.fp8_meta["scaling_bwd"].scale_inv, ctx.fp8_meta["scaling_bwd"].scale_inv,
0, _GRAD_OUTPUT,
fp8_dtype_backward, fp8_dtype_backward,
dgrad_list, torch.split(dgrad, ctx.m_splits),
ctx.activation_dtype, ctx.activation_dtype,
get_multi_stream_cublas_workspace(), get_multi_stream_cublas_workspace(),
use_split_accumulator=_2X_ACC_DGRAD, use_split_accumulator=_2X_ACC_DGRAD,
) )
# unpad the output
dgrad = torch.cat(
[d[: ctx.m_splits[i]] for i, d in enumerate(dgrad_list)], dim=0
)
else: else:
logger.debug("Running backward in %s", ctx.activation_dtype) logger.debug("Running backward in %s", ctx.activation_dtype)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment