Unverified Commit e113bf84 authored by Przemyslaw Tredak's avatar Przemyslaw Tredak Committed by GitHub
Browse files

[pyTorch] Fix wrong results for noncontiguous input (#1017)



* Ensure that the inputs to custom calls are contiguous
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* Fixes
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* Added test
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* Fixes
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fixes from review
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix typo
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent c8c05f38
...@@ -1816,3 +1816,52 @@ def test_fp8_grouped_gemm(shape, fp8_dtype, accumulate): ...@@ -1816,3 +1816,52 @@ def test_fp8_grouped_gemm(shape, fp8_dtype, accumulate):
# should be bit-wise match # should be bit-wise match
for o, o_ref in zip(out, out_ref): for o, o_ref in zip(out, out_ref):
torch.testing.assert_close(o, o_ref, rtol=0, atol=0) torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
def test_noncontiguous():
    # Regression test: custom CUDA extensions previously produced wrong results
    # for noncontiguous inputs. Runs each module on a noncontiguous tensor and
    # on its contiguous copy and checks that forward outputs and gradients match.

    def _create2modules(m, params):
        # Build two instances of module class `m` with identical weights so the
        # contiguous and noncontiguous runs start from the same parameters.
        mod1 = m(*params)
        mod2 = m(*params)
        for p1, p2 in zip(mod1.parameters(), mod2.parameters()):
            p2.data = p1.data.clone()
        return mod1, mod2

    def _run_module(m, inp):
        # Forward + backward pass; collect the output, the input gradient (if
        # any), and all parameter gradients for comparison.
        out = m(inp)
        out.sum().backward()
        ret = [out]
        # NOTE(review): `inp` here is a non-leaf tensor (a transpose / a
        # .contiguous() result), so inp.grad is presumably None — confirm.
        if inp.grad is not None:
            ret.append(inp.grad)
        for p in m.parameters():
            if p.requires_grad:
                ret.append(p.grad)
        return ret

    # Transposing a contiguous 2D tensor yields a noncontiguous view.
    a = torch.randn((128, 256), device="cuda", requires_grad=True)
    a = a.T
    assert not a.is_contiguous(), "The test is supposed to test noncontiguous input."
    b = a.contiguous()

    # LayerNorm
    ln1, ln2 = _create2modules(LayerNorm, [128])
    outT = _run_module(ln1, a)
    out = _run_module(ln2, b)
    assert_allclose(out, outT, 1e-7)

    # RMSNorm
    ln1, ln2 = _create2modules(RMSNorm, [128])
    outT = _run_module(ln1, a)
    out = _run_module(ln2, b)
    assert_allclose(out, outT, 1e-7)

    # GEMM
    g1, g2 = _create2modules(Linear, [128, 128])
    outT = _run_module(g1, a)
    out = _run_module(g2, b)
    assert_allclose(out, outT, 1e-7)
...@@ -54,6 +54,9 @@ def fp8_gemm( ...@@ -54,6 +54,9 @@ def fp8_gemm(
dtype=out_dtype, dtype=out_dtype,
device="cuda", device="cuda",
) )
else:
if not out.is_contiguous():
raise ValueError("Output tensor is not contiguous.")
# Use bfloat16 as default bias_dtype # Use bfloat16 as default bias_dtype
bias_dtype = torch.bfloat16 if bias is None else bias.dtype bias_dtype = torch.bfloat16 if bias is None else bias.dtype
...@@ -202,6 +205,9 @@ def gemm( ...@@ -202,6 +205,9 @@ def gemm(
dtype=dtype, dtype=dtype,
device="cuda", device="cuda",
) )
else:
if not out.is_contiguous():
raise ValueError("Output tensor is not contiguous.")
if gelu and not grad: if gelu and not grad:
gelu_input = torch.empty_like(out, dtype=dtype) gelu_input = torch.empty_like(out, dtype=dtype)
...@@ -311,7 +317,9 @@ def grouped_gemm( ...@@ -311,7 +317,9 @@ def grouped_gemm(
empty_tensors = [torch.Tensor()] * num_gemms empty_tensors = [torch.Tensor()] * num_gemms
if gelu and not grad: if gelu and not grad:
gelu_input = [torch.empty_like(o, dtype=dtype) for o in out] gelu_input = [
torch.empty_like(o, dtype=dtype, memory_format=torch.contiguous_format) for o in out
]
elif not gelu: elif not gelu:
gelu_input = empty_tensors gelu_input = empty_tensors
...@@ -406,7 +414,10 @@ def fp8_grouped_gemm( ...@@ -406,7 +414,10 @@ def fp8_grouped_gemm(
# Use bfloat16 as default bias_dtype # Use bfloat16 as default bias_dtype
bias_dtype = torch.bfloat16 if bias is None else bias[0].dtype bias_dtype = torch.bfloat16 if bias is None else bias[0].dtype
if gelu: if gelu:
gelu_input = [torch.empty_like(o, dtype=bias_dtype) for o in out] gelu_input = [
torch.empty_like(o, dtype=bias_dtype, memory_format=torch.contiguous_format)
for o in out
]
else: else:
gelu_input = empty_tensors gelu_input = empty_tensors
bias_dtype = TE_DType[bias_dtype] bias_dtype = TE_DType[bias_dtype]
......
...@@ -21,6 +21,10 @@ void te_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType ...@@ -21,6 +21,10 @@ void te_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType
if (pre_gelu_out.data_ptr() != nullptr) pre_gelu_out.zero_(); if (pre_gelu_out.data_ptr() != nullptr) pre_gelu_out.zero_();
return; return;
} }
A = A.contiguous();
B = B.contiguous();
auto te_A = makeTransformerEngineTensor( auto te_A = makeTransformerEngineTensor(
A.data_ptr(), {static_cast<size_t>(A.size(0)), static_cast<size_t>(A.size(1))}, A_type, A.data_ptr(), {static_cast<size_t>(A.size(0)), static_cast<size_t>(A.size(1))}, A_type,
nullptr, nullptr, A_scale_inverse.data_ptr()); nullptr, nullptr, A_scale_inverse.data_ptr());
......
...@@ -10,16 +10,22 @@ std::vector<at::Tensor> layernorm_bwd(const at::Tensor &dz, const at::Tensor &x, ...@@ -10,16 +10,22 @@ std::vector<at::Tensor> layernorm_bwd(const at::Tensor &dz, const at::Tensor &x,
const at::Tensor &mu, const at::Tensor &rsigma, const at::Tensor &mu, const at::Tensor &rsigma,
const at::Tensor &gamma, const int sm_margin, const at::Tensor &gamma, const int sm_margin,
const bool zero_centered_gamma) { const bool zero_centered_gamma) {
auto dx = at::empty_like(x); const auto &dz_ = dz.contiguous();
auto dgamma = at::empty_like(gamma); const auto &x_ = x.contiguous();
auto dbeta = at::empty_like(gamma); const auto &mu_ = mu.contiguous();
const auto &rsigma_ = rsigma.contiguous();
const auto &gamma_ = gamma.contiguous();
auto dx = at::empty_like(x_);
auto dgamma = at::empty_like(gamma_);
auto dbeta = at::empty_like(gamma_);
transformer_engine::TensorWrapper workspace, barrier, dgamma_part, dbeta_part; transformer_engine::TensorWrapper workspace, barrier, dgamma_part, dbeta_part;
auto dz_cu = makeTransformerEngineTensor(dz); auto dz_cu = makeTransformerEngineTensor(dz_);
auto x_cu = makeTransformerEngineTensor(x); auto x_cu = makeTransformerEngineTensor(x_);
auto mu_cu = makeTransformerEngineTensor(mu); auto mu_cu = makeTransformerEngineTensor(mu_);
auto rsigma_cu = makeTransformerEngineTensor(rsigma); auto rsigma_cu = makeTransformerEngineTensor(rsigma_);
auto gamma_cu = makeTransformerEngineTensor(gamma); auto gamma_cu = makeTransformerEngineTensor(gamma_);
auto dx_cu = makeTransformerEngineTensor(dx); auto dx_cu = makeTransformerEngineTensor(dx);
auto dgamma_cu = makeTransformerEngineTensor(dgamma); auto dgamma_cu = makeTransformerEngineTensor(dgamma);
auto dbeta_cu = makeTransformerEngineTensor(dbeta); auto dbeta_cu = makeTransformerEngineTensor(dbeta);
...@@ -63,8 +69,10 @@ std::vector<at::Tensor> layernorm_fwd_fp8(const at::Tensor &input, const at::Ten ...@@ -63,8 +69,10 @@ std::vector<at::Tensor> layernorm_fwd_fp8(const at::Tensor &input, const at::Ten
const int amax_offset, const int scale_inv_offset) { const int amax_offset, const int scale_inv_offset) {
using namespace transformer_engine; using namespace transformer_engine;
auto ln_out = at::empty_like(input, at::CUDA(GetATenDType(otype))); const auto &input_ = input.contiguous();
return layernorm_fwd_fp8_noalloc(input, weight, bias, eps, scale, ln_out, amax, scale_inv, otype,
auto ln_out = at::empty_like(input_, at::CUDA(GetATenDType(otype)));
return layernorm_fwd_fp8_noalloc(input_, weight, bias, eps, scale, ln_out, amax, scale_inv, otype,
sm_margin, zero_centered_gamma, scale_offset, amax_offset, sm_margin, zero_centered_gamma, scale_offset, amax_offset,
scale_inv_offset); scale_inv_offset);
} }
...@@ -76,6 +84,10 @@ std::vector<at::Tensor> layernorm_fwd_fp8_noalloc( ...@@ -76,6 +84,10 @@ std::vector<at::Tensor> layernorm_fwd_fp8_noalloc(
const int scale_offset, const int amax_offset, const int scale_inv_offset) { const int scale_offset, const int amax_offset, const int scale_inv_offset) {
using namespace transformer_engine; using namespace transformer_engine;
const auto &input_ = input.contiguous();
const auto &weight_ = weight.contiguous();
const auto &bias_ = bias.contiguous();
// Choose kernel implementation // Choose kernel implementation
const auto func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd; const auto func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd;
...@@ -92,9 +104,9 @@ std::vector<at::Tensor> layernorm_fwd_fp8_noalloc( ...@@ -92,9 +104,9 @@ std::vector<at::Tensor> layernorm_fwd_fp8_noalloc(
DType itype = GetTransformerEngineDType(input.scalar_type()); DType itype = GetTransformerEngineDType(input.scalar_type());
auto mu = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat)); auto mu = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
auto rsigma = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat)); auto rsigma = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
auto input_cu = makeTransformerEngineTensor(input); auto input_cu = makeTransformerEngineTensor(input_);
auto gamma_cu = makeTransformerEngineTensor(weight); auto gamma_cu = makeTransformerEngineTensor(weight_);
auto beta_cu = makeTransformerEngineTensor(bias); auto beta_cu = makeTransformerEngineTensor(bias_);
auto z_cu = makeTransformerEngineTensor(ln_out.data_ptr(), {N, H}, otype, amax_dptr, scale_dptr, auto z_cu = makeTransformerEngineTensor(ln_out.data_ptr(), {N, H}, otype, amax_dptr, scale_dptr,
scale_inv_dptr); scale_inv_dptr);
auto mu_cu = makeTransformerEngineTensor(mu); auto mu_cu = makeTransformerEngineTensor(mu);
...@@ -145,9 +157,10 @@ std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input, const at::Tensor ...@@ -145,9 +157,10 @@ std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input, const at::Tensor
using namespace transformer_engine; using namespace transformer_engine;
DType itype = GetTransformerEngineDType(input.scalar_type()); DType itype = GetTransformerEngineDType(input.scalar_type());
auto ln_out = at::empty_like(input, at::CUDA(GetATenDType(itype))); const auto &input_ = input.contiguous();
auto ln_out = at::empty_like(input_, at::CUDA(GetATenDType(itype)));
return layernorm_fwd_noalloc(input, weight, bias, ln_out, eps, sm_margin, zero_centered_gamma); return layernorm_fwd_noalloc(input_, weight, bias, ln_out, eps, sm_margin, zero_centered_gamma);
} }
std::vector<at::Tensor> layernorm_fwd_noalloc(const at::Tensor &input, const at::Tensor &weight, std::vector<at::Tensor> layernorm_fwd_noalloc(const at::Tensor &input, const at::Tensor &weight,
...@@ -174,14 +187,19 @@ at::Tensor layernorm_fwd_inf(const at::Tensor &input, const at::Tensor &weight, ...@@ -174,14 +187,19 @@ at::Tensor layernorm_fwd_inf(const at::Tensor &input, const at::Tensor &weight,
std::vector<at::Tensor> rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x, std::vector<at::Tensor> rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x,
const at::Tensor &rsigma, const at::Tensor &gamma, const at::Tensor &rsigma, const at::Tensor &gamma,
const int sm_margin, const bool zero_centered_gamma) { const int sm_margin, const bool zero_centered_gamma) {
auto dx = at::empty_like(x); const auto &dz_ = dz.contiguous();
auto dgamma = at::empty_like(gamma); const auto &x_ = x.contiguous();
const auto &rsigma_ = rsigma.contiguous();
const auto &gamma_ = gamma.contiguous();
auto dx = at::empty_like(x_);
auto dgamma = at::empty_like(gamma_);
transformer_engine::TensorWrapper workspace, barrier, dgamma_part; transformer_engine::TensorWrapper workspace, barrier, dgamma_part;
auto dz_cu = makeTransformerEngineTensor(dz); auto dz_cu = makeTransformerEngineTensor(dz_);
auto x_cu = makeTransformerEngineTensor(x); auto x_cu = makeTransformerEngineTensor(x_);
auto rsigma_cu = makeTransformerEngineTensor(rsigma); auto rsigma_cu = makeTransformerEngineTensor(rsigma_);
auto gamma_cu = makeTransformerEngineTensor(gamma); auto gamma_cu = makeTransformerEngineTensor(gamma_);
auto dx_cu = makeTransformerEngineTensor(dx); auto dx_cu = makeTransformerEngineTensor(dx);
auto dgamma_cu = makeTransformerEngineTensor(dgamma); auto dgamma_cu = makeTransformerEngineTensor(dgamma);
...@@ -219,8 +237,11 @@ std::vector<at::Tensor> rmsnorm_fwd_fp8(const at::Tensor &input, const at::Tenso ...@@ -219,8 +237,11 @@ std::vector<at::Tensor> rmsnorm_fwd_fp8(const at::Tensor &input, const at::Tenso
const int scale_inv_offset) { const int scale_inv_offset) {
using namespace transformer_engine; using namespace transformer_engine;
auto ln_out = at::empty_like(input, at::CUDA(GetATenDType(otype))); const auto &input_ = input.contiguous();
return rmsnorm_fwd_fp8_noalloc(input, weight, eps, scale, ln_out, amax, scale_inv, otype, const auto &weight_ = weight.contiguous();
auto ln_out = at::empty_like(input_, at::CUDA(GetATenDType(otype)));
return rmsnorm_fwd_fp8_noalloc(input_, weight_, eps, scale, ln_out, amax, scale_inv, otype,
sm_margin, zero_centered_gamma, scale_offset, amax_offset, sm_margin, zero_centered_gamma, scale_offset, amax_offset,
scale_inv_offset); scale_inv_offset);
} }
...@@ -295,10 +316,13 @@ std::vector<at::Tensor> rmsnorm_fwd(const at::Tensor &input, const at::Tensor &w ...@@ -295,10 +316,13 @@ std::vector<at::Tensor> rmsnorm_fwd(const at::Tensor &input, const at::Tensor &w
const int sm_margin, const bool zero_centered_gamma) { const int sm_margin, const bool zero_centered_gamma) {
using namespace transformer_engine; using namespace transformer_engine;
const auto &input_ = input.contiguous();
const auto &weight_ = weight.contiguous();
DType itype = GetTransformerEngineDType(input.scalar_type()); DType itype = GetTransformerEngineDType(input.scalar_type());
auto ln_out = at::empty_like(input, at::CUDA(GetATenDType(itype))); auto ln_out = at::empty_like(input_, at::CUDA(GetATenDType(itype)));
return rmsnorm_fwd_noalloc(input, weight, ln_out, eps, sm_margin, zero_centered_gamma); return rmsnorm_fwd_noalloc(input_, weight_, ln_out, eps, sm_margin, zero_centered_gamma);
} }
std::vector<at::Tensor> rmsnorm_fwd_noalloc(const at::Tensor &input, const at::Tensor &weight, std::vector<at::Tensor> rmsnorm_fwd_noalloc(const at::Tensor &input, const at::Tensor &weight,
......
...@@ -130,12 +130,14 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -130,12 +130,14 @@ class _LayerNormLinear(torch.autograd.Function):
if return_layernorm_output: if return_layernorm_output:
# First prepare LN output in higher precision, # First prepare LN output in higher precision,
# which will be later copied to a FP8 UB # which will be later copied to a FP8 UB
ln_out = torch.empty_like(inputmat) ln_out = torch.empty_like(inputmat, memory_format=torch.contiguous_format)
else: else:
ln_out = ub_obj_lnout.get_ubuf_output(0) ln_out = ub_obj_lnout.get_ubuf_output(0)
else: else:
ln_out_dtype = torch.uint8 if (fp8 and not return_layernorm_output) else inputmat.dtype ln_out_dtype = torch.uint8 if (fp8 and not return_layernorm_output) else inputmat.dtype
ln_out = torch.empty_like(inputmat, dtype=ln_out_dtype) ln_out = torch.empty_like(
inputmat, dtype=ln_out_dtype, memory_format=torch.contiguous_format
)
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
......
...@@ -149,7 +149,9 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -149,7 +149,9 @@ class _LayerNormMLP(torch.autograd.Function):
ln_out = ub_obj_lnout.get_ubuf_output(0) ln_out = ub_obj_lnout.get_ubuf_output(0)
else: else:
ln_out_dtype = torch.uint8 if (fp8 and not return_layernorm_output) else inputmat.dtype ln_out_dtype = torch.uint8 if (fp8 and not return_layernorm_output) else inputmat.dtype
ln_out = torch.empty_like(inputmat, dtype=ln_out_dtype) ln_out = torch.empty_like(
inputmat, dtype=ln_out_dtype, memory_format=torch.contiguous_format
)
ub_overlap_rs = False if tp_world_size == 1 else ub_overlap_rs ub_overlap_rs = False if tp_world_size == 1 else ub_overlap_rs
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment