Unverified commit e113bf84 authored by Przemyslaw Tredak, committed by GitHub

[pyTorch] Fix wrong results for noncontiguous input (#1017)



* Ensure that the inputs to custom calls are contiguous
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fixes
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Added test
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fixes
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fixes from review
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix typo
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent c8c05f38
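
For context (editorial note, not part of the commit): a noncontiguous tensor, such as a transposed or sliced view, shares storage with its base tensor but has non-row-major strides. The C++/CUDA extensions changed below wrap inputs using only a raw data pointer and a shape, so they implicitly assume a dense row-major layout; feeding them a noncontiguous view therefore produces wrong results. The sketch below is a minimal illustration of that layout mismatch and of the `.contiguous()` calls this commit adds at the extension boundaries.

```python
import torch

# Minimal illustration (not from this commit): a transposed view is noncontiguous.
a = torch.randn(128, 256)
at = a.T                          # shape (256, 128), strides (1, 256) -> not row-major
assert not at.is_contiguous()

# A kernel reading at.data_ptr() as a dense row-major (256, 128) buffer would see
# the untransposed data. Materializing the view fixes the memory layout:
at_dense = at.contiguous()
assert at_dense.is_contiguous()
assert torch.equal(at, at_dense)  # same values, now in dense row-major storage
```
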
@@ -1816,3 +1816,52 @@ def test_fp8_grouped_gemm(shape, fp8_dtype, accumulate):
     # should be bit-wise match
     for o, o_ref in zip(out, out_ref):
         torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
+
+
+def test_noncontiguous():
+    def _create2modules(m, params):
+        mod1 = m(*params)
+        mod2 = m(*params)
+        for p1, p2 in zip(mod1.parameters(), mod2.parameters()):
+            p2.data = p1.data.clone()
+        return mod1, mod2
+
+    def _run_module(m, inp):
+        out = m(inp)
+        out.sum().backward()
+        ret = [out]
+        if inp.grad is not None:
+            ret.append(inp.grad)
+        for p in m.parameters():
+            if p.requires_grad:
+                ret.append(p.grad)
+        return ret
+
+    a = torch.randn((128, 256), device="cuda", requires_grad=True)
+    a = a.T
+    assert not a.is_contiguous(), "The test is supposed to test noncontiguous input."
+    b = a.contiguous()
+
+    # LayerNorm
+    ln1, ln2 = _create2modules(LayerNorm, [128])
+    outT = _run_module(ln1, a)
+    out = _run_module(ln2, b)
+
+    assert_allclose(out, outT, 1e-7)
+
+    # RMSNorm
+    ln1, ln2 = _create2modules(RMSNorm, [128])
+    outT = _run_module(ln1, a)
+    out = _run_module(ln2, b)
+
+    assert_allclose(out, outT, 1e-7)
+
+    # GEMM
+    g1, g2 = _create2modules(Linear, [128, 128])
+    outT = _run_module(g1, a)
+    out = _run_module(g2, b)
+
+    assert_allclose(out, outT, 1e-7)
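
Aside on the test above (illustrative, not part of the commit): because `a` is rebound to the transposed view, both `a` and `b` are non-leaf tensors, so autograd does not populate their `.grad` by default; the `if inp.grad is not None` guard in `_run_module` accounts for that. A small standalone sketch of this PyTorch behavior:

```python
import torch

x = torch.randn(4, 4, requires_grad=True)
y = x.T                    # non-leaf view of x
y.sum().backward()
print(x.grad is None)      # False: gradients accumulate on leaf tensors
print(y.grad is None)      # True: non-leaf tensors get no .grad by default
                           # (accessing it may emit a UserWarning)
```
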
@@ -54,6 +54,9 @@ def fp8_gemm(
             dtype=out_dtype,
             device="cuda",
         )
+    else:
+        if not out.is_contiguous():
+            raise ValueError("Output tensor is not contiguous.")
 
     # Use bfloat16 as default bias_dtype
     bias_dtype = torch.bfloat16 if bias is None else bias.dtype
@@ -202,6 +205,9 @@ def gemm(
             dtype=dtype,
             device="cuda",
         )
+    else:
+        if not out.is_contiguous():
+            raise ValueError("Output tensor is not contiguous.")
 
     if gelu and not grad:
         gelu_input = torch.empty_like(out, dtype=dtype)
@@ -311,7 +317,9 @@ def grouped_gemm(
     empty_tensors = [torch.Tensor()] * num_gemms
     if gelu and not grad:
-        gelu_input = [torch.empty_like(o, dtype=dtype) for o in out]
+        gelu_input = [
+            torch.empty_like(o, dtype=dtype, memory_format=torch.contiguous_format) for o in out
+        ]
     elif not gelu:
         gelu_input = empty_tensors
@@ -406,7 +414,10 @@ def fp8_grouped_gemm(
     # Use bfloat16 as default bias_dtype
     bias_dtype = torch.bfloat16 if bias is None else bias[0].dtype
     if gelu:
-        gelu_input = [torch.empty_like(o, dtype=bias_dtype) for o in out]
+        gelu_input = [
+            torch.empty_like(o, dtype=bias_dtype, memory_format=torch.contiguous_format)
+            for o in out
+        ]
     else:
         gelu_input = empty_tensors
     bias_dtype = TE_DType[bias_dtype]
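
Note on the `memory_format` changes in this file and in the modules further down (editorial aside, not from the commit): `torch.empty_like` defaults to `torch.preserve_format`, so a buffer allocated from a noncontiguous tensor inherits its noncontiguous strides. Passing `memory_format=torch.contiguous_format` forces a dense row-major allocation, which is what the kernels writing into these buffers expect. A minimal sketch of that behavior:

```python
import torch

x = torch.randn(128, 256).T                  # noncontiguous view
buf = torch.empty_like(x)                    # preserve_format copies the strides
print(buf.is_contiguous())                   # False
buf = torch.empty_like(x, memory_format=torch.contiguous_format)
print(buf.is_contiguous())                   # True: dense row-major buffer
```
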
@@ -21,6 +21,10 @@ void te_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType
     if (pre_gelu_out.data_ptr() != nullptr) pre_gelu_out.zero_();
     return;
   }
+
+  A = A.contiguous();
+  B = B.contiguous();
+
   auto te_A = makeTransformerEngineTensor(
       A.data_ptr(), {static_cast<size_t>(A.size(0)), static_cast<size_t>(A.size(1))}, A_type,
       nullptr, nullptr, A_scale_inverse.data_ptr());
@@ -10,16 +10,22 @@ std::vector<at::Tensor> layernorm_bwd(const at::Tensor &dz, const at::Tensor &x,
                                       const at::Tensor &mu, const at::Tensor &rsigma,
                                       const at::Tensor &gamma, const int sm_margin,
                                       const bool zero_centered_gamma) {
-  auto dx = at::empty_like(x);
-  auto dgamma = at::empty_like(gamma);
-  auto dbeta = at::empty_like(gamma);
+  const auto &dz_ = dz.contiguous();
+  const auto &x_ = x.contiguous();
+  const auto &mu_ = mu.contiguous();
+  const auto &rsigma_ = rsigma.contiguous();
+  const auto &gamma_ = gamma.contiguous();
+
+  auto dx = at::empty_like(x_);
+  auto dgamma = at::empty_like(gamma_);
+  auto dbeta = at::empty_like(gamma_);
   transformer_engine::TensorWrapper workspace, barrier, dgamma_part, dbeta_part;
 
-  auto dz_cu = makeTransformerEngineTensor(dz);
-  auto x_cu = makeTransformerEngineTensor(x);
-  auto mu_cu = makeTransformerEngineTensor(mu);
-  auto rsigma_cu = makeTransformerEngineTensor(rsigma);
-  auto gamma_cu = makeTransformerEngineTensor(gamma);
+  auto dz_cu = makeTransformerEngineTensor(dz_);
+  auto x_cu = makeTransformerEngineTensor(x_);
+  auto mu_cu = makeTransformerEngineTensor(mu_);
+  auto rsigma_cu = makeTransformerEngineTensor(rsigma_);
+  auto gamma_cu = makeTransformerEngineTensor(gamma_);
   auto dx_cu = makeTransformerEngineTensor(dx);
   auto dgamma_cu = makeTransformerEngineTensor(dgamma);
   auto dbeta_cu = makeTransformerEngineTensor(dbeta);
@@ -63,8 +69,10 @@ std::vector<at::Tensor> layernorm_fwd_fp8(const at::Tensor &input, const at::Ten
                                           const int amax_offset, const int scale_inv_offset) {
   using namespace transformer_engine;
 
-  auto ln_out = at::empty_like(input, at::CUDA(GetATenDType(otype)));
-  return layernorm_fwd_fp8_noalloc(input, weight, bias, eps, scale, ln_out, amax, scale_inv, otype,
+  const auto &input_ = input.contiguous();
+
+  auto ln_out = at::empty_like(input_, at::CUDA(GetATenDType(otype)));
+  return layernorm_fwd_fp8_noalloc(input_, weight, bias, eps, scale, ln_out, amax, scale_inv, otype,
                                    sm_margin, zero_centered_gamma, scale_offset, amax_offset,
                                    scale_inv_offset);
 }
@@ -76,6 +84,10 @@ std::vector<at::Tensor> layernorm_fwd_fp8_noalloc(
     const int scale_offset, const int amax_offset, const int scale_inv_offset) {
   using namespace transformer_engine;
 
+  const auto &input_ = input.contiguous();
+  const auto &weight_ = weight.contiguous();
+  const auto &bias_ = bias.contiguous();
+
   // Choose kernel implementation
   const auto func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd;
@@ -92,9 +104,9 @@ std::vector<at::Tensor> layernorm_fwd_fp8_noalloc(
   DType itype = GetTransformerEngineDType(input.scalar_type());
   auto mu = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
   auto rsigma = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
-  auto input_cu = makeTransformerEngineTensor(input);
-  auto gamma_cu = makeTransformerEngineTensor(weight);
-  auto beta_cu = makeTransformerEngineTensor(bias);
+  auto input_cu = makeTransformerEngineTensor(input_);
+  auto gamma_cu = makeTransformerEngineTensor(weight_);
+  auto beta_cu = makeTransformerEngineTensor(bias_);
   auto z_cu = makeTransformerEngineTensor(ln_out.data_ptr(), {N, H}, otype, amax_dptr, scale_dptr,
                                           scale_inv_dptr);
   auto mu_cu = makeTransformerEngineTensor(mu);
@@ -145,9 +157,10 @@ std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input, const at::Tensor
   using namespace transformer_engine;
   DType itype = GetTransformerEngineDType(input.scalar_type());
-  auto ln_out = at::empty_like(input, at::CUDA(GetATenDType(itype)));
+  const auto &input_ = input.contiguous();
+  auto ln_out = at::empty_like(input_, at::CUDA(GetATenDType(itype)));
 
-  return layernorm_fwd_noalloc(input, weight, bias, ln_out, eps, sm_margin, zero_centered_gamma);
+  return layernorm_fwd_noalloc(input_, weight, bias, ln_out, eps, sm_margin, zero_centered_gamma);
 }
 
 std::vector<at::Tensor> layernorm_fwd_noalloc(const at::Tensor &input, const at::Tensor &weight,
@@ -174,14 +187,19 @@ at::Tensor layernorm_fwd_inf(const at::Tensor &input, const at::Tensor &weight,
 
 std::vector<at::Tensor> rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x,
                                     const at::Tensor &rsigma, const at::Tensor &gamma,
                                     const int sm_margin, const bool zero_centered_gamma) {
-  auto dx = at::empty_like(x);
-  auto dgamma = at::empty_like(gamma);
+  const auto &dz_ = dz.contiguous();
+  const auto &x_ = x.contiguous();
+  const auto &rsigma_ = rsigma.contiguous();
+  const auto &gamma_ = gamma.contiguous();
+
+  auto dx = at::empty_like(x_);
+  auto dgamma = at::empty_like(gamma_);
   transformer_engine::TensorWrapper workspace, barrier, dgamma_part;
 
-  auto dz_cu = makeTransformerEngineTensor(dz);
-  auto x_cu = makeTransformerEngineTensor(x);
-  auto rsigma_cu = makeTransformerEngineTensor(rsigma);
-  auto gamma_cu = makeTransformerEngineTensor(gamma);
+  auto dz_cu = makeTransformerEngineTensor(dz_);
+  auto x_cu = makeTransformerEngineTensor(x_);
+  auto rsigma_cu = makeTransformerEngineTensor(rsigma_);
+  auto gamma_cu = makeTransformerEngineTensor(gamma_);
   auto dx_cu = makeTransformerEngineTensor(dx);
   auto dgamma_cu = makeTransformerEngineTensor(dgamma);
@@ -219,8 +237,11 @@ std::vector<at::Tensor> rmsnorm_fwd_fp8(const at::Tensor &input, const at::Tenso
                                         const int scale_inv_offset) {
   using namespace transformer_engine;
 
-  auto ln_out = at::empty_like(input, at::CUDA(GetATenDType(otype)));
-  return rmsnorm_fwd_fp8_noalloc(input, weight, eps, scale, ln_out, amax, scale_inv, otype,
+  const auto &input_ = input.contiguous();
+  const auto &weight_ = weight.contiguous();
+
+  auto ln_out = at::empty_like(input_, at::CUDA(GetATenDType(otype)));
+  return rmsnorm_fwd_fp8_noalloc(input_, weight_, eps, scale, ln_out, amax, scale_inv, otype,
                                  sm_margin, zero_centered_gamma, scale_offset, amax_offset,
                                  scale_inv_offset);
 }
@@ -295,10 +316,13 @@ std::vector<at::Tensor> rmsnorm_fwd(const at::Tensor &input, const at::Tensor &w
                                     const int sm_margin, const bool zero_centered_gamma) {
   using namespace transformer_engine;
 
+  const auto &input_ = input.contiguous();
+  const auto &weight_ = weight.contiguous();
+
   DType itype = GetTransformerEngineDType(input.scalar_type());
-  auto ln_out = at::empty_like(input, at::CUDA(GetATenDType(itype)));
+  auto ln_out = at::empty_like(input_, at::CUDA(GetATenDType(itype)));
 
-  return rmsnorm_fwd_noalloc(input, weight, ln_out, eps, sm_margin, zero_centered_gamma);
+  return rmsnorm_fwd_noalloc(input_, weight_, ln_out, eps, sm_margin, zero_centered_gamma);
 }
 
 std::vector<at::Tensor> rmsnorm_fwd_noalloc(const at::Tensor &input, const at::Tensor &weight,
@@ -130,12 +130,14 @@ class _LayerNormLinear(torch.autograd.Function):
             if return_layernorm_output:
                 # First prepare LN output in higher precision,
                 # which will be later copied to a FP8 UB
-                ln_out = torch.empty_like(inputmat)
+                ln_out = torch.empty_like(inputmat, memory_format=torch.contiguous_format)
             else:
                 ln_out = ub_obj_lnout.get_ubuf_output(0)
         else:
             ln_out_dtype = torch.uint8 if (fp8 and not return_layernorm_output) else inputmat.dtype
-            ln_out = torch.empty_like(inputmat, dtype=ln_out_dtype)
+            ln_out = torch.empty_like(
+                inputmat, dtype=ln_out_dtype, memory_format=torch.contiguous_format
+            )
 
         fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
@@ -149,7 +149,9 @@ class _LayerNormMLP(torch.autograd.Function):
            ln_out = ub_obj_lnout.get_ubuf_output(0)
         else:
             ln_out_dtype = torch.uint8 if (fp8 and not return_layernorm_output) else inputmat.dtype
-            ln_out = torch.empty_like(inputmat, dtype=ln_out_dtype)
+            ln_out = torch.empty_like(
+                inputmat, dtype=ln_out_dtype, memory_format=torch.contiguous_format
+            )
 
         ub_overlap_rs = False if tp_world_size == 1 else ub_overlap_rs
         fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)