Commit 544dd14b authored by Przemek Tredak

Update main branch with TE 2.0 code, update version to 2.1.0.dev0


Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
parent e5369541
@@ -6,10 +6,29 @@
#include "extensions.h"
std::vector<at::Tensor> layernorm_bwd(const at::Tensor &dz, const at::Tensor &x,
namespace transformer_engine::pytorch {
std::pair<TensorWrapper, py::object> createOutputTensor(const NVTEShape &shape, DType dtype,
py::handle quantizer) {
std::vector<size_t> shape_vec;
for (int i = 0; i < shape.ndim; i++) {
size_t t = shape.data[i];
shape_vec.push_back(t);
}
std::unique_ptr<Quantizer> my_quantizer = convert_quantizer(quantizer);
return my_quantizer->create_tensor(shape_vec, dtype);
}
std::pair<TensorWrapper, py::object> createOutputTensor(std::vector<size_t> &shape, DType dtype,
py::handle quantizer) {
std::unique_ptr<Quantizer> my_quantizer = convert_quantizer(quantizer);
return my_quantizer->create_tensor(shape, dtype);
}
} // namespace transformer_engine::pytorch
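As a rough illustration (not part of this commit), a call site elsewhere in this extension could allocate a kernel output through either overload; the shape and dtype below are placeholders and quantizer is assumed to be a py::handle received from Python:
  // Hypothetical usage of the new helpers, assuming this file's includes.
  std::vector<size_t> out_shape = {128, 1024};  // placeholder dimensions
  auto [out_tensor, out_pyobj] =
      createOutputTensor(out_shape, DType::kBFloat16, quantizer);
  // out_tensor feeds the NVTE kernel call; out_pyobj is what eventually
  // gets handed back to the Python caller.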
std::vector<py::object> layernorm_bwd(const at::Tensor &dz, const at::Tensor &x,
const at::Tensor &mu, const at::Tensor &rsigma,
const at::Tensor &gamma, const int sm_margin,
const bool zero_centered_gamma) {
using namespace transformer_engine::pytorch;
const auto &dz_ = dz.contiguous();
const auto &x_ = x.contiguous();
const auto &mu_ = mu.contiguous();
@@ -47,61 +66,57 @@ std::vector<at::Tensor> layernorm_bwd(const at::Tensor &dz, const at::Tensor &x,
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
zero_centered_gamma, at::cuda::getCurrentCUDAStream());
return {dx, dgamma, dbeta};
return {py::cast(dx), py::cast(dgamma), py::cast(dbeta)};
}
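Since the backward entry points now return py::object directly, the pybind11 registration can stay a plain m.def. A hedged sketch only; the module object, docstring, and argument names are illustrative and the real registration lives in the extension's pybind file:
  m.def("layernorm_bwd", &layernorm_bwd, "LayerNorm backward pass",
        py::arg("dz"), py::arg("x"), py::arg("mu"), py::arg("rsigma"),
        py::arg("gamma"), py::arg("sm_margin"), py::arg("zero_centered_gamma"));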
std::vector<at::Tensor> layernorm_fwd_fp8(const at::Tensor &input, const at::Tensor &weight,
const at::Tensor &bias, float eps, at::Tensor scale,
at::Tensor amax, at::Tensor scale_inv,
transformer_engine::DType otype, const int sm_margin,
const bool zero_centered_gamma, const int scale_offset,
const int amax_offset, const int scale_inv_offset) {
std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, MaybeTensor bias,
float eps, py::object ln_out, py::handle quantizer,
DType out_dtype, const int sm_margin,
const bool zero_centered_gamma) {
using namespace transformer_engine::pytorch;
using namespace transformer_engine;
const auto &input_ = input.contiguous();
auto none = py::none();
const TensorWrapper &input_tensor = makeTransformerEngineTensor(input, none);
const TensorWrapper &weight_tensor = makeTransformerEngineTensor(weight, none);
auto ln_out = at::empty_like(input_, at::CUDA(GetATenDType(otype)));
return layernorm_fwd_fp8_noalloc(input_, weight, bias, eps, scale, ln_out, amax, scale_inv, otype,
sm_margin, zero_centered_gamma, scale_offset, amax_offset,
scale_inv_offset);
}
std::vector<at::Tensor> layernorm_fwd_fp8_noalloc(
const at::Tensor &input, const at::Tensor &weight, const at::Tensor &bias, float eps,
at::Tensor scale, at::Tensor ln_out, at::Tensor amax, at::Tensor scale_inv,
transformer_engine::DType otype, const int sm_margin, const bool zero_centered_gamma,
const int scale_offset, const int amax_offset, const int scale_inv_offset) {
using namespace transformer_engine;
const auto &input_ = input.contiguous();
const auto &weight_ = weight.contiguous();
const auto &bias_ = bias.contiguous();
TensorWrapper bias_tensor;
MaybeTensor bias_grad = std::nullopt;
if (bias.has_value()) {
bias_tensor = makeTransformerEngineTensor(*bias);
}
// Tensor dimensions
size_t N = static_cast<size_t>(input.size(0));
size_t H = static_cast<size_t>(input.size(1));
// Get pointers for FP8 scale, amax, scale-inverse
void *scale_dptr = getDataPtr(scale, scale_offset);
void *amax_dptr = getDataPtr(amax, amax_offset);
void *scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset);
size_t N = static_cast<size_t>(input_tensor.size(0));
size_t H = static_cast<size_t>(input_tensor.size(1));
std::vector<size_t> size = {N, H};
// Construct Transformer Engine tensors
DType itype = GetTransformerEngineDType(input.scalar_type());
auto mu = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
auto rsigma = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
auto input_cu = makeTransformerEngineTensor(input_);
auto gamma_cu = makeTransformerEngineTensor(weight_);
auto beta_cu = makeTransformerEngineTensor(bias_);
auto z_cu = makeTransformerEngineTensor(ln_out.data_ptr(), {N, H}, otype, amax_dptr, scale_dptr,
scale_inv_dptr);
auto mu_cu = makeTransformerEngineTensor(mu);
auto rsigma_cu = makeTransformerEngineTensor(rsigma);
at::Tensor mu = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
at::Tensor rsigma = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
TensorWrapper ln_out_tensor;
std::unique_ptr<Quantizer> my_quantizer = convert_quantizer(quantizer);
py::object ln_output;
if (my_quantizer->get_scaling_mode() == NVTE_MXFP8_1D_SCALING) {
// Use high precision output from normalization
NoneQuantizer q{none};
std::tie(ln_out_tensor, ln_output) = q.create_tensor(size, out_dtype);
} else {
if (ln_out.is_none()) {
std::tie(ln_out_tensor, ln_out) = my_quantizer->create_tensor(size, out_dtype);
} else {
ln_out_tensor = makeTransformerEngineTensor(ln_out, quantizer);
}
}
TensorWrapper mu_cu = makeTransformerEngineTensor(mu);
TensorWrapper rsigma_cu = makeTransformerEngineTensor(rsigma);
// Query workspace sizes
transformer_engine::TensorWrapper workspace;
nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
mu_cu.data(), rsigma_cu.data(), workspace.data(),
nvte_layernorm_fwd(input_tensor.data(), weight_tensor.data(), bias_tensor.data(), eps,
ln_out_tensor.data(), mu_cu.data(), rsigma_cu.data(), workspace.data(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
zero_centered_gamma, at::cuda::getCurrentCUDAStream());
@@ -111,66 +126,30 @@ std::vector<at::Tensor> layernorm_fwd_fp8_noalloc(
makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype());
// Launch kernel
nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
mu_cu.data(), rsigma_cu.data(), workspace.data(),
nvte_layernorm_fwd(input_tensor.data(), weight_tensor.data(), bias_tensor.data(), eps,
ln_out_tensor.data(), mu_cu.data(), rsigma_cu.data(), workspace.data(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
zero_centered_gamma, at::cuda::getCurrentCUDAStream());
return {ln_out, mu, rsigma};
}
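The double call to nvte_layernorm_fwd above follows the usual NVTE workspace convention: the first call with an empty workspace only reports the required scratch size, the second call launches the kernel. Schematically, reusing the names from this function; num_sm and stream abbreviate the multiProcessorCount/stream expressions in the diff, and allocateSpace is assumed to be the allocation helper used elsewhere in this extension:
  // 1) Query: an empty TensorWrapper makes the call fill in
  //    workspace.shape()/dtype() instead of launching the kernel.
  TensorWrapper workspace;
  nvte_layernorm_fwd(input_tensor.data(), weight_tensor.data(), bias_tensor.data(), eps,
                     ln_out_tensor.data(), mu_cu.data(), rsigma_cu.data(), workspace.data(),
                     num_sm, zero_centered_gamma, stream);
  // 2) Allocate the requested scratch space and rebind the wrapper to it.
  auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
  workspace = makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(),
                                          workspace.dtype());
  // 3) Launch: the same call again, now with a real workspace buffer.
  nvte_layernorm_fwd(input_tensor.data(), weight_tensor.data(), bias_tensor.data(), eps,
                     ln_out_tensor.data(), mu_cu.data(), rsigma_cu.data(), workspace.data(),
                     num_sm, zero_centered_gamma, stream);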
at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input, const at::Tensor &weight,
const at::Tensor &bias, float eps, at::Tensor scale,
at::Tensor amax, at::Tensor scale_inv,
transformer_engine::DType otype, const int sm_margin,
const bool zero_centered_gamma, const int scale_offset,
const int amax_offset, const int scale_inv_offset
) {
// This is a specialized version of layernorm_fwd_fp8, optimized for inference,
// which only returns the normalized output.
std::vector<at::Tensor> out =
layernorm_fwd_fp8(input, weight, bias, eps, scale, amax, scale_inv, otype, sm_margin,
zero_centered_gamma, scale_offset, amax_offset, scale_inv_offset);
return out[0];
}
std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input, const at::Tensor &weight,
const at::Tensor &bias, float eps, const int sm_margin,
const bool zero_centered_gamma) {
using namespace transformer_engine;
DType itype = GetTransformerEngineDType(input.scalar_type());
const auto &input_ = input.contiguous();
auto ln_out = at::empty_like(input_, at::CUDA(GetATenDType(itype)));
return layernorm_fwd_noalloc(input_, weight, bias, ln_out, eps, sm_margin, zero_centered_gamma);
}
std::vector<at::Tensor> layernorm_fwd_noalloc(const at::Tensor &input, const at::Tensor &weight,
const at::Tensor &bias, at::Tensor ln_out, float eps,
const int sm_margin, const bool zero_centered_gamma) {
using namespace transformer_engine;
if (my_quantizer->get_scaling_mode() == NVTE_MXFP8_1D_SCALING) {
TensorWrapper cast_out_tensor;
if (ln_out.is_none()) {
std::tie(cast_out_tensor, ln_out) = my_quantizer->create_tensor(size, out_dtype);
} else {
cast_out_tensor = makeTransformerEngineTensor(ln_out, quantizer);
}
DType itype = GetTransformerEngineDType(input.scalar_type());
nvte_quantize_noop(ln_out_tensor.data(), cast_out_tensor.data(), nullptr,
at::cuda::getCurrentCUDAStream());
}
return layernorm_fwd_fp8_noalloc(input, weight, bias, eps, at::Tensor(), ln_out, at::Tensor(),
at::Tensor(), itype, sm_margin, zero_centered_gamma);
return {ln_out, py::cast(mu), py::cast(rsigma)};
}
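For readability, the MXFP8 special case in the new layernorm_fwd amounts to the following two-step sequence: the kernel normalizes into a high-precision buffer, and a separate cast produces the final quantized output. This just regroups lines already present in the hunks above:
  if (my_quantizer->get_scaling_mode() == NVTE_MXFP8_1D_SCALING) {
    // The kernel already wrote high-precision output into ln_out_tensor.
    TensorWrapper cast_out_tensor;
    if (ln_out.is_none()) {
      std::tie(cast_out_tensor, ln_out) = my_quantizer->create_tensor(size, out_dtype);
    } else {
      cast_out_tensor = makeTransformerEngineTensor(ln_out, quantizer);
    }
    // Separate quantization pass producing the final MXFP8 output.
    nvte_quantize_noop(ln_out_tensor.data(), cast_out_tensor.data(), nullptr,
                       at::cuda::getCurrentCUDAStream());
  }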
at::Tensor layernorm_fwd_inf(const at::Tensor &input, const at::Tensor &weight,
const at::Tensor &bias, float eps, const int sm_margin,
const bool zero_centered_gamma) {
// This is a specialized version of layernorm_fwd, optimized for inference,
// which only returns the normalized output.
std::vector<at::Tensor> out =
layernorm_fwd(input, weight, bias, eps, sm_margin, zero_centered_gamma);
return out[0];
}
std::vector<at::Tensor> rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x,
std::vector<py::object> rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x,
const at::Tensor &rsigma, const at::Tensor &gamma,
const int sm_margin, const bool zero_centered_gamma) {
using namespace transformer_engine::pytorch;
const auto &dz_ = dz.contiguous();
const auto &x_ = x.contiguous();
const auto &rsigma_ = rsigma.contiguous();
@@ -204,57 +183,48 @@ std::vector<at::Tensor> rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x,
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
zero_centered_gamma, at::cuda::getCurrentCUDAStream());
return {dx, dgamma};
return {py::cast(dx), py::cast(dgamma)};
}
std::vector<at::Tensor> rmsnorm_fwd_fp8(const at::Tensor &input, const at::Tensor &weight,
float eps, at::Tensor scale, at::Tensor amax,
at::Tensor scale_inv, transformer_engine::DType otype,
const int sm_margin, const bool zero_centered_gamma,
const int scale_offset, const int amax_offset,
const int scale_inv_offset) {
std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &weight, float eps,
py::object ln_out, py::handle quantizer,
transformer_engine::DType otype, const int sm_margin,
const bool zero_centered_gamma) {
using namespace transformer_engine::pytorch;
using namespace transformer_engine;
const auto &input_ = input.contiguous();
const auto &weight_ = weight.contiguous();
auto ln_out = at::empty_like(input_, at::CUDA(GetATenDType(otype)));
return rmsnorm_fwd_fp8_noalloc(input_, weight_, eps, scale, ln_out, amax, scale_inv, otype,
sm_margin, zero_centered_gamma, scale_offset, amax_offset,
scale_inv_offset);
}
std::vector<at::Tensor> rmsnorm_fwd_fp8_noalloc(const at::Tensor &input, const at::Tensor &weight,
float eps, at::Tensor scale, at::Tensor ln_out,
at::Tensor amax, at::Tensor scale_inv,
transformer_engine::DType otype,
const int sm_margin, const bool zero_centered_gamma,
const int scale_offset, const int amax_offset,
const int scale_inv_offset) {
using namespace transformer_engine;
auto none = py::none();
const TensorWrapper &input_tensor = makeTransformerEngineTensor(input, none);
const TensorWrapper &weight_tensor = makeTransformerEngineTensor(weight, none);
// Tensor dimensions
size_t N = static_cast<size_t>(input.size(0));
size_t H = static_cast<size_t>(input.size(1));
// Get pointers for FP8 scale, amax, scale-inverse
void *scale_dptr = getDataPtr(scale, scale_offset);
void *amax_dptr = getDataPtr(amax, amax_offset);
void *scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset);
size_t N = static_cast<size_t>(input_tensor.shape().data[0]);
size_t H = static_cast<size_t>(input_tensor.shape().data[1]);
// Construct Transformer Engine tensors
DType itype = GetTransformerEngineDType(input.scalar_type());
auto rsigma = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
auto input_cu = makeTransformerEngineTensor(input);
auto gamma_cu = makeTransformerEngineTensor(weight);
auto z_cu = makeTransformerEngineTensor(ln_out.data_ptr(), {N, H}, otype, amax_dptr, scale_dptr,
scale_inv_dptr);
std::vector<size_t> size = {N, H};
TensorWrapper ln_out_tensor;
std::unique_ptr<Quantizer> my_quantizer = convert_quantizer(quantizer);
py::object ln_output;
if (my_quantizer->get_scaling_mode() == NVTE_MXFP8_1D_SCALING) {
// Use high precision output from normalization
NoneQuantizer q{none};
std::tie(ln_out_tensor, ln_output) = q.create_tensor(size, otype);
} else {
if (ln_out.is_none()) {
std::tie(ln_out_tensor, ln_out) = my_quantizer->create_tensor(size, otype);
} else {
ln_out_tensor = makeTransformerEngineTensor(ln_out, quantizer);
}
}
auto rsigma_cu = makeTransformerEngineTensor(rsigma);
// Query workspace sizes
transformer_engine::TensorWrapper workspace;
nvte_rmsnorm_fwd(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(),
workspace.data(),
nvte_rmsnorm_fwd(input_tensor.data(), weight_tensor.data(), eps, ln_out_tensor.data(),
rsigma_cu.data(), workspace.data(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
zero_centered_gamma, at::cuda::getCurrentCUDAStream());
@@ -264,55 +234,22 @@ std::vector<at::Tensor> rmsnorm_fwd_fp8_noalloc(const at::Tensor &input, const a
makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype());
// Launch kernel
nvte_rmsnorm_fwd(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(),
workspace.data(),
nvte_rmsnorm_fwd(input_tensor.data(), weight_tensor.data(), eps, ln_out_tensor.data(),
rsigma_cu.data(), workspace.data(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
zero_centered_gamma, at::cuda::getCurrentCUDAStream());
return {ln_out, rsigma};
}
if (my_quantizer->get_scaling_mode() == NVTE_MXFP8_1D_SCALING) {
TensorWrapper cast_out_tensor;
if (ln_out.is_none()) {
std::tie(cast_out_tensor, ln_out) = my_quantizer->create_tensor(size, otype);
} else {
cast_out_tensor = makeTransformerEngineTensor(ln_out, quantizer);
}
at::Tensor rmsnorm_fwd_fp8_inf(const at::Tensor &input, const at::Tensor &weight, float eps,
at::Tensor scale, at::Tensor amax, at::Tensor scale_inv,
transformer_engine::DType otype, const int sm_margin,
const bool zero_centered_gamma, const int scale_offset,
const int amax_offset, const int scale_inv_offset) {
// This is a specialized version of rmsnorm_fwd_fp8, optimized for inference,
// which only returns the normalized output.
std::vector<at::Tensor> out =
rmsnorm_fwd_fp8(input, weight, eps, scale, amax, scale_inv, otype, sm_margin,
zero_centered_gamma, scale_offset, amax_offset, scale_inv_offset);
return out[0];
}
std::vector<at::Tensor> rmsnorm_fwd(const at::Tensor &input, const at::Tensor &weight, float eps,
const int sm_margin, const bool zero_centered_gamma) {
using namespace transformer_engine;
const auto &input_ = input.contiguous();
const auto &weight_ = weight.contiguous();
DType itype = GetTransformerEngineDType(input.scalar_type());
auto ln_out = at::empty_like(input_, at::CUDA(GetATenDType(itype)));
return rmsnorm_fwd_noalloc(input_, weight_, ln_out, eps, sm_margin, zero_centered_gamma);
}
std::vector<at::Tensor> rmsnorm_fwd_noalloc(const at::Tensor &input, const at::Tensor &weight,
at::Tensor ln_out, float eps, const int sm_margin,
const bool zero_centered_gamma) {
using namespace transformer_engine;
DType itype = GetTransformerEngineDType(input.scalar_type());
return rmsnorm_fwd_fp8_noalloc(input, weight, eps, at::Tensor(), ln_out, at::Tensor(),
at::Tensor(), itype, sm_margin, zero_centered_gamma);
}
nvte_quantize_noop(ln_out_tensor.data(), cast_out_tensor.data(), nullptr,
at::cuda::getCurrentCUDAStream());
}
at::Tensor rmsnorm_fwd_inf(const at::Tensor &input, const at::Tensor &weight, float eps,
const int sm_margin, const bool zero_centered_gamma) {
// This is a specialized version of rmsnorm_fwd, optimized for inference,
// which only returns the normalized output.
std::vector<at::Tensor> out = rmsnorm_fwd(input, weight, eps, sm_margin, zero_centered_gamma);
return out[0];
return {ln_out, py::none(), py::cast(rsigma)};
}
@@ -10,6 +10,7 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output,
std::vector<size_t> input_row_list,
std::vector<size_t> padded_input_row_list) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
NVTE_CHECK(input_row_list.size() == padded_input_row_list.size(),
"Number of input row list and padded row list must match.");
......
@@ -11,6 +11,7 @@
std::tuple<at::Tensor, at::Tensor, std::vector<at::Tensor>> moe_permute_fwd(
at::Tensor input, const transformer_engine::DType dtype, at::Tensor indices,
int64_t num_out_tokens, std::vector<at::Tensor> workspace, int64_t max_expanded_token_num) {
using namespace transformer_engine::pytorch;
const int num_tokens = input.size(0);
int num_cols = input.size(1);
const int topK = indices.size(1);
@@ -96,6 +97,7 @@ at::Tensor moe_permute_bwd(at::Tensor input, const transformer_engine::DType dty
at::Tensor moe_unpermute_fwd(at::Tensor input, const transformer_engine::DType dtype,
at::Tensor row_id_map, at::Tensor prob, int64_t num_tokens,
int64_t topK) {
using namespace transformer_engine::pytorch;
int num_cols = input.size(1);
// Activations type
@@ -129,6 +131,7 @@ at::Tensor moe_unpermute_fwd(at::Tensor input, const transformer_engine::DType d
std::tuple<at::Tensor, at::Tensor> moe_unpermute_bwd(at::Tensor input_bwd, at::Tensor input_fwd,
const transformer_engine::DType dtype,
at::Tensor row_id_map, at::Tensor prob) {
using namespace transformer_engine::pytorch;
const int topK = (prob.numel() > 0) ? prob.size(1) : 1;
const int num_tokens = (prob.numel() > 0) ? prob.size(0) : row_id_map.size(0);
int num_cols = input_bwd.size(1);
......