Unverified Commit a435ec01 authored by Xin Yao, committed by GitHub

[PyTorch] Remove implicit padding and unpadding in `GroupedLinear` (#984)



* remove implicit padding and unpadding
Signed-off-by: Xin Yao <xiny@nvidia.com>
---------
Signed-off-by: Xin Yao <xiny@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com>
parent e3bb24e5
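With `_pad_tensor` gone from the FP8 path, the per-split row counts handed to `GroupedLinear` must already satisfy `assert_dim_for_fp8_exec`; the layer no longer rounds each split up to a multiple of 16 on its own. A minimal sketch of caller-side padding, assuming the same multiple-of-16 row requirement the removed helper enforced (the name `pad_for_fp8_gemm` is illustrative, not part of this PR or the TE API):

```python
import torch


def pad_for_fp8_gemm(inp, m_splits, multiple=16):
    """Round each split's row count up to a multiple of `multiple` by appending
    zero rows; return the padded tensor and the padded split sizes."""
    padded_splits = [(m + multiple - 1) // multiple * multiple for m in m_splits]
    padded_chunks = []
    for chunk, target in zip(torch.split(inp, m_splits), padded_splits):
        pad_rows = target - chunk.shape[0]
        if pad_rows:
            # append zero rows so this split meets the FP8 GEMM alignment
            chunk = torch.cat((chunk, chunk.new_zeros(pad_rows, chunk.shape[1])), dim=0)
        padded_chunks.append(chunk)
    return torch.cat(padded_chunks, dim=0), padded_splits
```

A caller would do something like `inp, m_splits = pad_for_fp8_gemm(inp, m_splits)` before the forward pass and drop the padded rows from the output afterwards, rather than relying on the layer to do it implicitly.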
@@ -71,14 +71,6 @@ _GEMM_OUTPUT = 0
 _GRAD_OUTPUT = 0
-def _pad_tensor(inp: torch.Tensor):
-    if inp.shape[0] % 16 == 0:
-        return inp
-    pad_len = (inp.shape[0] + 15) // 16 * 16 - inp.shape[0]
-    pad_tensor = torch.zeros(pad_len, inp.shape[1], dtype=inp.dtype, device=inp.device)
-    return torch.cat((inp, pad_tensor), dim=0)
 class _GroupedLinear(torch.autograd.Function):
     """GroupedLinear semi-top level module
     Calls custom cuda extensions.
@@ -116,7 +108,6 @@ class _GroupedLinear(torch.autograd.Function):
         assert inp.shape[-1] == in_features, "GEMM not possible"
         inputmats = torch.split(inp.view(-1, in_features), m_splits)
         if fp8:
-            inputmats = [_pad_tensor(mat) for mat in inputmats]
             for i in range(num_gemms):
                 assert_dim_for_fp8_exec(inputmats[i])
                 assert_dim_for_fp8_exec(weights[i])
@@ -170,14 +161,11 @@ class _GroupedLinear(torch.autograd.Function):
                 weights_fp8 = weights
             assert all(isinstance(w, Float8Tensor) for w in weights_fp8)
-            out_list = [
-                torch.empty(
-                    [inputmats[i].size(0), weights_fp8[0].size(0)],
-                    dtype=activation_dtype,
-                    device=inputmats[i].device,
-                )
-                for i in range(num_gemms)
-            ]
+            out = torch.empty(
+                [sum(m_splits), weights_fp8[0].size(0)],
+                dtype=activation_dtype,
+                device=inputmats[0].device,
+            )
             _ = fp8_grouped_gemm(
                 [w._data for w in weights_fp8],
@@ -188,15 +176,13 @@ class _GroupedLinear(torch.autograd.Function):
                 fp8_meta["scaling_fwd"].scale_inv,
                 _GEMM_INPUT,
                 fp8_dtype_forward,
-                out_list,
+                torch.split(out, m_splits),
                 activation_dtype,
                 get_multi_stream_cublas_workspace(),
                 bias=biases,
                 use_bias=use_bias,
                 use_split_accumulator=_2X_ACC_FPROP,
             )
-            # unpad the output
-            out = torch.cat([o[: m_splits[i]] for i, o in enumerate(out_list)], dim=0)
         else:
             logger.debug("Running forward in %s", activation_dtype)
@@ -333,7 +319,6 @@ class _GroupedLinear(torch.autograd.Function):
         if ctx.fp8:
             fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
             fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)
-            grad_output_mats = [_pad_tensor(mat) for mat in grad_output_mats]
             if ctx.use_bias:
                 for i in range(ctx.num_gemms):
                     grad_biases[i], grad_output_c[i], grad_output_t[i] = (
@@ -372,14 +357,11 @@ class _GroupedLinear(torch.autograd.Function):
         if ctx.requires_dgrad:
             if ctx.fp8:
                 logger.debug("Running backward in FP8")
-                dgrad_list = [
-                    torch.empty(
-                        (grad_output_c[i].size(0), weights_fp8[i].size(1)),
-                        dtype=ctx.activation_dtype,
-                        device=grad_output.device,
-                    )
-                    for i in range(ctx.num_gemms)
-                ]
+                dgrad = torch.empty(
+                    (sum(ctx.m_splits), weights_fp8[0].size(1)),
+                    dtype=ctx.activation_dtype,
+                    device=grad_output.device,
+                )
                 fp8_grouped_gemm(
                     [w.transpose_2d() for w in weights_fp8],
                     torch.cat(
@@ -389,17 +371,13 @@ class _GroupedLinear(torch.autograd.Function):
                     weights_fp8[0]._fp8_dtype,
                     grad_output_c,
                     ctx.fp8_meta["scaling_bwd"].scale_inv,
-                    0,
+                    _GRAD_OUTPUT,
                     fp8_dtype_backward,
-                    dgrad_list,
+                    torch.split(dgrad, ctx.m_splits),
                     ctx.activation_dtype,
                     get_multi_stream_cublas_workspace(),
                     use_split_accumulator=_2X_ACC_DGRAD,
                 )
-                # unpad the output
-                dgrad = torch.cat(
-                    [d[: ctx.m_splits[i]] for i, d in enumerate(dgrad_list)], dim=0
-                )
             else:
                 logger.debug("Running backward in %s", ctx.activation_dtype)