Unverified commit cfbbfb89, authored by Kirthi Shankar Sivamani and committed by GitHub

Cleanup pytorch extensions (#1781)



* rm unused swizzle extensions
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix swizzle
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Consistent namespaces and first refactor
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* format and lint
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* transformer_engine
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* revert accidental perm change
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent f966d5f7
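The thread running through the hunks below is a single pattern: rather than opening each function body with "using namespace transformer_engine;" / "using namespace transformer_engine::pytorch;", each extension translation unit is now wrapped in one namespace transformer_engine::pytorch { ... } block, which also lets fully qualified spellings such as transformer_engine::DType and transformer_engine::TensorWrapper shrink to DType and TensorWrapper. A minimal sketch of the before/after shape (some_op is a hypothetical stand-in, not a function from this change):

#include <vector>

// Before the cleanup: every function pulled the namespace in locally.
//
//   std::vector<int> some_op(int n) {
//     using namespace transformer_engine::pytorch;
//     ...
//   }

// After the cleanup: the whole file lives inside the namespace, so
// types like DType and TensorWrapper need no qualifier here.
namespace transformer_engine::pytorch {

std::vector<int> some_op(int n) {  // hypothetical stand-in for a real op
  return std::vector<int>(n, 0);
}

}  // namespace transformer_engine::pytorch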
......@@ -19,11 +19,7 @@ def setup_pytorch_extension(
"""Setup CUDA extension for PyTorch support"""
# Source files
csrc_source_files = Path(csrc_source_files)
extensions_dir = csrc_source_files / "extensions"
sources = [
csrc_source_files / "common.cpp",
] + all_files_in_dir(extensions_dir)
sources = all_files_in_dir(Path(csrc_source_files), name_extension="cpp")
# Header files
include_dirs = get_cuda_include_dirs()
......
......@@ -56,7 +56,7 @@ def all_files_in_dir(path, name_extension=None):
all_files = []
for dirname, _, names in os.walk(path):
for name in names:
if name_extension is not None and name_extension not in name:
if name_extension is not None and not name.endswith(f".{name_extension}"):
continue
all_files.append(Path(dirname, name))
return all_files
......
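The all_files_in_dir fix above tightens the filter from a substring test to a true suffix match: the old check, name_extension not in name, would also accept a stray file such as common.cpp.orig, while name.endswith(f".{name_extension}") keeps only real .cpp sources. A rough C++ analogue of the corrected behavior, sketched with std::filesystem purely for illustration (this helper is not part of the change):

#include <filesystem>
#include <string>
#include <vector>

namespace fs = std::filesystem;

// Recursively collect files whose extension is exactly ".<ext>".
// A substring test ("cpp" anywhere in the name) would also match names
// like "common.cpp.orig"; comparing the extension enforces a real suffix.
std::vector<fs::path> all_files_in_dir(const fs::path& root, const std::string& ext) {
  std::vector<fs::path> files;
  for (const auto& entry : fs::recursive_directory_iterator(root)) {
    if (entry.is_regular_file() && entry.path().extension() == "." + ext) {
      files.push_back(entry.path());
    }
  }
  return files;
}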
......@@ -9,6 +9,7 @@
#include "c10/util/ArrayRef.h"
#include "pybind.h"
#include "transformer_engine/transformer_engine.h"
namespace transformer_engine::pytorch {
std::vector<size_t> getTensorShape(at::Tensor t) {
......
......@@ -11,32 +11,31 @@
#include "common.h"
namespace transformer_engine::pytorch {
/***************************************************************************************************
* Permutation
**************************************************************************************************/
std::tuple<at::Tensor, at::Tensor, std::vector<at::Tensor>> moe_permute_fwd(
at::Tensor input, const transformer_engine::DType dtype, at::Tensor indices,
int64_t num_out_tokens, std::vector<at::Tensor> workspace, int64_t max_expanded_token_num);
at::Tensor input, const DType dtype, at::Tensor indices, int64_t num_out_tokens,
std::vector<at::Tensor> workspace, int64_t max_expanded_token_num);
at::Tensor moe_permute_bwd(at::Tensor input, const transformer_engine::DType dtype,
at::Tensor row_id_map, at::Tensor prob, int64_t num_tokens,
int64_t topK);
at::Tensor moe_permute_bwd(at::Tensor input, const DType dtype, at::Tensor row_id_map,
at::Tensor prob, int64_t num_tokens, int64_t topK);
at::Tensor moe_unpermute_fwd(at::Tensor input, const transformer_engine::DType dtype,
at::Tensor row_id_map, at::Tensor prob, int64_t num_tokens,
int64_t topK);
at::Tensor moe_unpermute_fwd(at::Tensor input, const DType dtype, at::Tensor row_id_map,
at::Tensor prob, int64_t num_tokens, int64_t topK);
std::tuple<at::Tensor, at::Tensor> moe_unpermute_bwd(at::Tensor input_bwd, at::Tensor input_fwd,
const transformer_engine::DType dtype,
at::Tensor row_id_map, at::Tensor prob);
const DType dtype, at::Tensor row_id_map,
at::Tensor prob);
/***************************************************************************************************
* Attention
**************************************************************************************************/
NVTE_Fused_Attn_Backend get_fused_attn_backend(const transformer_engine::DType q_dtype,
const transformer_engine::DType kv_dtype,
NVTE_Fused_Attn_Backend get_fused_attn_backend(const DType q_dtype, const DType kv_dtype,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, float p_dropout,
size_t num_attn_heads, size_t num_gqa_groups,
......@@ -61,8 +60,8 @@ std::vector<py::object> fused_attn_bwd(
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
const std::vector<int64_t> window_size, bool deterministic, const at::Tensor cu_seqlens_q,
const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V,
const py::handle O, const py::handle dO, const at::ScalarType fake_dtype,
const transformer_engine::DType dqkv_type, const std::vector<at::Tensor> Aux_CTX_Tensors,
const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors,
const std::optional<at::Tensor> cu_seqlens_q_padded,
const std::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
py::handle dp_quantizer, py::handle dqkv_quantizer);
......@@ -83,25 +82,29 @@ void copy_to_kv_cache(at::Tensor new_k, at::Tensor new_v, at::Tensor k_cache, at
using MaybeTensor = std::optional<at::Tensor>;
void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType A_type,
std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool transb, py::object D,
py::handle quantizer, std::optional<DType> out_dtype, MaybeTensor bias,
DType bias_type, bool gelu, MaybeTensor gelu_in, bool grad,
at::Tensor workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator, CommOverlapCore *comm_overlap = nullptr,
std::optional<CommOverlapType> comm_type = std::nullopt,
MaybeTensor extra_output = std::nullopt, bool bulk_overlap = false);
void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type,
std::vector<int64_t> A_scaling_mode, bool transa, at::Tensor B,
at::Tensor B_scale_inverse, transformer_engine::DType B_type,
std::vector<int64_t> B_scaling_mode, bool transb, at::Tensor D,
at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax,
at::Tensor bias, transformer_engine::DType bias_type, at::Tensor pre_gelu_out,
bool grad, at::Tensor workspace, size_t workspaceSize, bool accumulate,
at::Tensor B_scale_inverse, DType B_type, std::vector<int64_t> B_scaling_mode,
bool transb, at::Tensor D, at::Tensor D_scale, DType D_type, at::Tensor D_amax,
at::Tensor bias, DType bias_type, at::Tensor pre_gelu_out, bool grad,
at::Tensor workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator, int math_sm_count, int m_split, int n_split,
bool gemm_producer, at::Tensor counter);
std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
std::vector<py::handle> A, bool transa, std::vector<py::handle> B, bool transb,
std::optional<std::vector<at::Tensor>> D, transformer_engine::DType D_type,
std::vector<int64_t> m_splits, std::vector<at::Tensor> bias,
transformer_engine::DType bias_type, bool single_output, std::vector<at::Tensor> pre_gelu_out,
bool grad, std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator, int math_sm_count);
namespace transformer_engine::pytorch {
std::optional<std::vector<at::Tensor>> D, DType D_type, std::vector<int64_t> m_splits,
std::vector<at::Tensor> bias, DType bias_type, bool single_output,
std::vector<at::Tensor> pre_gelu_out, bool grad, std::vector<at::Tensor> workspace,
size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count);
/***************************************************************************************************
* Transpose
......@@ -109,16 +112,11 @@ namespace transformer_engine::pytorch {
std::vector<py::object> fused_multi_quantize(std::vector<at::Tensor> input_list,
std::optional<std::vector<py::object>> output_list,
std::vector<py::handle> quantizer_list,
transformer_engine::DType otype);
std::vector<py::handle> quantizer_list, DType otype);
at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype,
at::Tensor fp8_transpose(at::Tensor input, DType otype,
std::optional<at::Tensor> output = std::nullopt);
} // namespace transformer_engine::pytorch
namespace transformer_engine::pytorch {
/***************************************************************************************************
* Activations
**************************************************************************************************/
......@@ -155,8 +153,6 @@ py::object dqgelu(const at::Tensor &grad, const at::Tensor &input, py::handle qu
py::object dsrelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
} // namespace transformer_engine::pytorch
/***************************************************************************************************
* LayerNorm
**************************************************************************************************/
......@@ -168,7 +164,7 @@ std::vector<py::object> layernorm_bwd(const at::Tensor &dz, const at::Tensor &x,
std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, MaybeTensor bias,
float eps, py::object ln_out, py::handle quantizer,
transformer_engine::DType out_dtype, const int sm_margin,
DType out_dtype, const int sm_margin,
const bool zero_centered_gamma);
/***************************************************************************************************
......@@ -180,35 +176,24 @@ std::vector<py::object> rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x,
const int sm_margin, const bool zero_centered_gamma);
std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &weight, float eps,
py::object ln_out, py::handle quantizer,
transformer_engine::DType otype, const int sm_margin,
const bool zero_centered_gamma);
py::object ln_out, py::handle quantizer, DType otype,
const int sm_margin, const bool zero_centered_gamma);
/***************************************************************************************************
* Cast
**************************************************************************************************/
namespace transformer_engine::pytorch {
py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::object &output,
std::optional<at::Tensor> noop);
py::object dequantize(const py::handle &input, transformer_engine::DType otype);
std::vector<py::object> bgrad_quantize(const at::Tensor &input, py::handle py_quantizer);
std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool transb, py::object D,
py::handle quantizer, std::optional<DType> out_dtype, MaybeTensor bias,
DType bias_type, bool gelu, MaybeTensor gelu_in, bool grad,
at::Tensor workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator, CommOverlapCore *comm_overlap = nullptr,
std::optional<CommOverlapType> comm_type = std::nullopt,
MaybeTensor extra_output = std::nullopt, bool bulk_overlap = false);
py::object dequantize(const py::handle &input, DType otype);
/***************************************************************************************************
* Cast fusions
* Bias gradient fusions
**************************************************************************************************/
std::vector<py::object> bgrad_quantize(const at::Tensor &input, py::handle py_quantizer);
std::vector<py::object> dbias_dgelu(const at::Tensor &grad_output, const at::Tensor &act_input,
py::handle quantizer);
......@@ -224,8 +209,6 @@ std::vector<py::object> dbias_dqgelu(const at::Tensor &grad_output, const at::Te
std::vector<py::object> dbias_dsrelu(const at::Tensor &grad_output, const at::Tensor &act_input,
py::handle quantizer);
} // namespace transformer_engine::pytorch
/***************************************************************************************************
* Softmax
**************************************************************************************************/
......@@ -262,7 +245,7 @@ void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reductio
std::vector<at::Tensor> amax_histories,
std::vector<at::Tensor> scales,
const std::string &amax_compute_algo,
transformer_engine::DType fp8_dtype, float margin);
DType fp8_dtype, float margin);
// Note that the start_offset is the logical offset along the tensor dimension.
// The offset in bytes is start_offset * sizeof(tensor.dtype)
......@@ -271,7 +254,7 @@ void fp8_block_scaling_compute_partial_amax(const at::Tensor &tensor, at::Tensor
void fp8_block_scaling_partial_cast(const at::Tensor &inp, at::Tensor out, const at::Tensor &scale,
size_t h, size_t w, size_t start_offset, size_t block_len,
const transformer_engine::DType out_dtype);
const DType out_dtype);
/***************************************************************************************************
* Rotary positional embedding
......@@ -335,7 +318,6 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_unscale_l2norm_cuda(
int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor inv_scale, at::optional<bool> per_tensor_python);
using transformer_engine::DType;
void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, const float lr,
const float beta1, const float beta2, const float epsilon,
......@@ -389,7 +371,6 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output,
* NVSHMEM APIs
**************************************************************************************************/
namespace nvshmem_api {
void init_nvshmem_backend(c10d::ProcessGroup *process_group);
at::Tensor create_nvshmem_tensor(const std::vector<int64_t> &shape, c10::ScalarType dtype);
......@@ -399,15 +380,8 @@ void nvshmem_send_on_current_stream(at::Tensor src, at::Tensor dst, int peer, at
void nvshmem_wait_on_current_stream(at::Tensor signal, const std::string &wait_kind);
void nvshmem_finalize();
} // namespace nvshmem_api
/***************************************************************************************************
* swizzle
**************************************************************************************************/
at::Tensor rowwise_swizzle(at::Tensor input, at::Tensor scale_inv);
at::Tensor columnwise_swizzle(at::Tensor input, at::Tensor scale_inv);
} // namespace transformer_engine::pytorch
/***************************************************************************************************
* Comm+GEMM Overlap Wrappers
......
......@@ -4,15 +4,16 @@
* See LICENSE for license information.
************************************************************************/
#include "common.h"
#include "extensions.h"
namespace transformer_engine::pytorch {
at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs,
const std::optional<at::Tensor> start_positions,
const NVTE_QKV_Format qkv_format, const bool interleaved,
const std::optional<at::Tensor> cu_seqlens, const int cp_size,
const int cp_rank) {
using namespace transformer_engine::pytorch;
TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor");
TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1,
"expected the second and third dims of the freqs tensor equal 1");
......@@ -27,7 +28,7 @@ at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs,
auto freqs_cu = makeTransformerEngineTensor(freqs);
auto output_cu = makeTransformerEngineTensor(output);
auto start_positions_cu = transformer_engine::TensorWrapper(); // empty cu_seqlens tensor
auto start_positions_cu = TensorWrapper(); // empty cu_seqlens tensor
if (start_positions) {
start_positions_cu = makeTransformerEngineTensor(start_positions.value());
}
......@@ -92,7 +93,7 @@ at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs,
"expected the last dim of the input tensor equals or is "
"greater than the freqs tensor");
auto cu_seqlens_cu = transformer_engine::TensorWrapper(); // empty cu_seqlens tensor
auto cu_seqlens_cu = TensorWrapper(); // empty cu_seqlens tensor
nvte_fused_rope_forward(input_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(),
start_positions_cu.data(), output_cu.data(), qkv_format, interleaved,
cp_size, cp_rank, s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d,
......@@ -105,7 +106,6 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor
const NVTE_QKV_Format qkv_format, const bool interleaved,
const std::optional<at::Tensor> cu_seqlens, const int cp_size,
const int cp_rank) {
using namespace transformer_engine::pytorch;
TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor");
TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1,
"expected the second and third dims of the freqs tensor equal 1");
......@@ -184,7 +184,7 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor
"expected the last dim of the output_grads tensor equals or is "
"greater than the freqs tensor");
auto cu_seqlens_cu = transformer_engine::TensorWrapper(); // empty cu_seqlens tensor
auto cu_seqlens_cu = TensorWrapper(); // empty cu_seqlens tensor
nvte_fused_rope_backward(output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(),
input_grads_cu.data(), qkv_format, interleaved, cp_size, cp_rank, s, b,
h, d, d2, stride_s, stride_b, stride_h, stride_d,
......@@ -192,3 +192,5 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor
return input_grads;
}
} // namespace transformer_engine::pytorch
......@@ -4,24 +4,13 @@
* See LICENSE for license information.
************************************************************************/
#include "common.h"
#include "extensions.h"
#include "pybind.h"
constexpr int block_size = 512;
namespace {
// get the fused attention backend
NVTE_Fused_Attn_Backend get_fused_attn_backend(
const transformer_engine::DType q_dtype, const transformer_engine::DType kv_dtype,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
int64_t window_size_right) {
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type,
attn_mask_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv,
head_dim_qk, head_dim_v, window_size_left, window_size_right);
return fused_attention_backend;
}
constexpr int block_size = 512;
// fast zero-fills of tensors
void mha_fill(const transformer_engine::TensorWrapper &self, const at::Tensor &start_index) {
......@@ -62,6 +51,23 @@ at::PhiloxCudaState init_philox_state(at::CUDAGeneratorImpl *gen, size_t elts_pe
return philox_args;
}
} // namespace
namespace transformer_engine::pytorch {
// get the fused attention backend
NVTE_Fused_Attn_Backend get_fused_attn_backend(
const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, float p_dropout, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v,
int64_t window_size_left, int64_t window_size_right) {
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type,
attn_mask_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv,
head_dim_qk, head_dim_v, window_size_left, window_size_right);
return fused_attention_backend;
}
// fused attention FWD with separate Q, K and V tensors
std::vector<py::object> fused_attn_fwd(
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout,
......@@ -74,8 +80,6 @@ std::vector<py::object> fused_attn_fwd(
const std::optional<at::Tensor> page_table_k, const std::optional<at::Tensor> page_table_v,
py::handle s_quantizer, py::handle o_quantizer, const std::optional<at::Tensor> Bias,
const std::optional<at::Generator> rng_gen, size_t rng_elts_per_thread) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
TensorWrapper te_Q, te_K, te_V, te_O, te_S;
auto none = py::none();
......@@ -87,8 +91,8 @@ std::vector<py::object> fused_attn_fwd(
te_V = makeTransformerEngineTensor(V, none);
// If qkv has FP8 dtype, fake_dtype_te is equal to the fake dtype of q, k, v - needed since torch does not have fp8 types.
const transformer_engine::DType qkv_type = te_Q.dtype();
const transformer_engine::DType fake_dtype_te = GetTransformerEngineDType(fake_dtype);
const DType qkv_type = te_Q.dtype();
const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype);
std::vector<size_t> q_shape = convertShape(te_Q.shape());
std::vector<size_t> k_shape = convertShape(te_K.shape());
......@@ -260,13 +264,11 @@ std::vector<py::object> fused_attn_bwd(
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
const std::vector<int64_t> window_size, bool deterministic, const at::Tensor cu_seqlens_q,
const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V,
const py::handle O, const py::handle dO, const at::ScalarType fake_dtype,
const transformer_engine::DType dqkv_type, const std::vector<at::Tensor> Aux_CTX_Tensors,
const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors,
const std::optional<at::Tensor> cu_seqlens_q_padded,
const std::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
py::handle dp_quantizer, py::handle dqkv_quantizer) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
auto none = py::none();
TensorWrapper te_Q, te_K, te_V, te_O, te_dO, te_S, te_dP, te_dQ, te_dK, te_dV;
te_Q = makeTransformerEngineTensor(Q, none);
......@@ -276,8 +278,8 @@ std::vector<py::object> fused_attn_bwd(
te_dO = makeTransformerEngineTensor(dO, none);
// qkv type from the te_Q
std::unique_ptr<Quantizer> dQKV_quantizer = convert_quantizer(dqkv_quantizer);
const transformer_engine::DType qkv_type = te_Q.dtype();
const transformer_engine::DType fake_dtype_te = GetTransformerEngineDType(fake_dtype);
const DType qkv_type = te_Q.dtype();
const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype);
py::object s_python, dp_python;
std::unique_ptr<Quantizer> S_quantizer = convert_quantizer(s_quantizer);
......@@ -497,9 +499,6 @@ std::vector<py::object> fused_attn_bwd(
}
at::Tensor fa_prepare_fwd(at::Tensor qkvi) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
NVTE_CHECK(qkvi.dim() == 4, "Expected 4-dim tensor.");
NVTE_CHECK(qkvi.scalar_type() == at::ScalarType::Half ||
qkvi.scalar_type() == at::ScalarType::BFloat16);
......@@ -521,9 +520,6 @@ at::Tensor fa_prepare_fwd(at::Tensor qkvi) {
}
at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
NVTE_CHECK(q.is_contiguous());
NVTE_CHECK(k.is_contiguous());
NVTE_CHECK(v.is_contiguous());
......@@ -556,9 +552,6 @@ at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v) {
at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_seqlens,
int half_idx) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
NVTE_CHECK(tensor.dim() == 3 || tensor.dim() == 4);
NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int);
NVTE_CHECK(cu_seqlens.dim() == 1);
......@@ -600,9 +593,6 @@ at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_s
void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_step,
const at::Tensor &cu_seqlens, bool lse_packed) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
NVTE_CHECK(lse.scalar_type() == at::ScalarType::Float);
NVTE_CHECK(lse_per_step.scalar_type() == at::ScalarType::Float);
NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int);
......@@ -647,9 +637,6 @@ void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_st
at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_seqlens,
bool lse_packed, int second_half_lse_seqlen) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
NVTE_CHECK(lse.scalar_type() == at::ScalarType::Float);
NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int);
NVTE_CHECK(cu_seqlens.dim() == 1);
......@@ -700,9 +687,6 @@ at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_
void thd_out_correction(at::Tensor out, const at::Tensor &out_per_step, const at::Tensor &lse,
const at::Tensor &lse_per_step, const at::Tensor &cu_seqlens,
bool only_second_half, bool lse_packed) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
auto te_out = makeTransformerEngineTensor(out);
auto te_out_per_step = makeTransformerEngineTensor(out_per_step);
auto te_lse = makeTransformerEngineTensor(lse);
......@@ -720,9 +704,6 @@ void thd_out_correction(at::Tensor out, const at::Tensor &out_per_step, const at
void thd_grad_correction(at::Tensor grad, const at::Tensor &grad_per_step,
const at::Tensor &cu_seqlens, const std::string &first_half,
const std::string &second_half) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
auto te_grad = makeTransformerEngineTensor(grad);
auto te_grad_per_step = makeTransformerEngineTensor(grad_per_step);
auto te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens);
......@@ -737,9 +718,6 @@ void thd_grad_correction(at::Tensor grad, const at::Tensor &grad_per_step,
at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_tokens,
int world_size, int rank) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int);
NVTE_CHECK(cu_seqlens.dim() == 1);
NVTE_CHECK(cu_seqlens.size(0) >= 2);
......@@ -766,9 +744,6 @@ at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_t
**************************************************************************************************/
at::Tensor convert_thd_to_bshd(at::Tensor tensor, at::Tensor cu_seqlens, int b, int max_seq_len) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
int h = tensor.size(1);
int d = tensor.size(2);
std::vector<int64_t> shape = {b, max_seq_len, h, d};
......@@ -789,9 +764,6 @@ at::Tensor convert_thd_to_bshd(at::Tensor tensor, at::Tensor cu_seqlens, int b,
**************************************************************************************************/
at::Tensor convert_bshd_to_thd(at::Tensor tensor, at::Tensor cu_seqlens, int t) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
int max_seq_len = tensor.size(1);
int h = tensor.size(2);
int d = tensor.size(3);
......@@ -812,9 +784,6 @@ void copy_to_kv_cache(at::Tensor new_k, at::Tensor new_v, at::Tensor k_cache, at
at::Tensor page_table, at::Tensor cu_new_lens, at::Tensor cu_cached_lens,
NVTE_QKV_Format qkv_format, int b, int max_ctx_len, int max_seq_len,
int max_pages_per_seq, bool is_non_paged) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
NVTE_CHECK(k_cache.scalar_type() == v_cache.scalar_type() &&
new_k.scalar_type() == new_v.scalar_type() &&
new_k.scalar_type() == k_cache.scalar_type(),
......@@ -836,3 +805,5 @@ void copy_to_kv_cache(at::Tensor new_k, at::Tensor new_v, at::Tensor k_cache, at
qkv_format, b, max_ctx_len, max_seq_len, max_pages_per_seq, is_non_paged,
at::cuda::getCurrentCUDAStream());
}
} // namespace transformer_engine::pytorch
......@@ -6,11 +6,10 @@
#include "extensions.h"
namespace transformer_engine::pytorch {
void fp8_block_scaling_compute_partial_amax(const at::Tensor &tensor, at::Tensor amax, size_t h,
size_t w, size_t start_offset, size_t block_len) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
TORCH_CHECK(block_len == 128, "Currently only block_len = 128 is supported");
TORCH_CHECK(amax.dim() == 2, "amax must be a 2D tensor");
TORCH_CHECK(amax.scalar_type() == at::ScalarType::Float, "amax must be a float tensor");
......@@ -29,9 +28,6 @@ void fp8_block_scaling_compute_partial_amax(const at::Tensor &tensor, at::Tensor
void fp8_block_scaling_partial_cast(const at::Tensor &inp, at::Tensor out, const at::Tensor &scale,
size_t h, size_t w, size_t start_offset, size_t block_len,
const transformer_engine::DType out_dtype) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
TORCH_CHECK(block_len == 128, "Currently only block_len = 128 is supported");
TORCH_CHECK(scale.dim() == 2, "scale must be a 2D tensor");
TORCH_CHECK(scale.scalar_type() == at::ScalarType::Float, "scale must be a float tensor");
......@@ -51,3 +47,5 @@ void fp8_block_scaling_partial_cast(const at::Tensor &inp, at::Tensor out, const
inp_cu.data(), out_cu.data(), scale_cu.data(), h, w, scale.stride(0), scale.stride(1),
start_offset, block_len, static_cast<NVTEDType>(out_dtype), at::cuda::getCurrentCUDAStream());
}
} // namespace transformer_engine::pytorch
......@@ -20,12 +20,12 @@
namespace {
void* get_data_ptr(MaybeTensor tensor) {
void* get_data_ptr(transformer_engine::pytorch::MaybeTensor tensor) {
if (tensor.has_value()) return tensor->data_ptr();
return nullptr;
}
size_t get_size(MaybeTensor tensor, int dim) {
size_t get_size(transformer_engine::pytorch::MaybeTensor tensor, int dim) {
if (tensor.has_value()) return static_cast<size_t>(tensor->size(dim));
return 0;
}
......@@ -271,20 +271,14 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
return out;
}
} // namespace transformer_engine::pytorch
void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType A_type,
void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type,
std::vector<int64_t> A_scaling_mode, bool transa, at::Tensor B,
at::Tensor B_scale_inverse, transformer_engine::DType B_type,
std::vector<int64_t> B_scaling_mode, bool transb, at::Tensor D,
at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax,
at::Tensor bias, transformer_engine::DType bias_type, at::Tensor pre_gelu_out,
bool grad, at::Tensor workspace, size_t workspaceSize, bool accumulate,
at::Tensor B_scale_inverse, DType B_type, std::vector<int64_t> B_scaling_mode,
bool transb, at::Tensor D, at::Tensor D_scale, DType D_type, at::Tensor D_amax,
at::Tensor bias, DType bias_type, at::Tensor pre_gelu_out, bool grad,
at::Tensor workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator, int math_sm_count, int m_split, int n_split,
bool gemm_producer, at::Tensor counter) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
// TODO: Handle scaling modes
NVTEScalingMode nvte_scaling_modeA = NVTE_DELAYED_TENSOR_SCALING;
NVTEScalingMode nvte_scaling_modeB = NVTE_DELAYED_TENSOR_SCALING;
......@@ -326,13 +320,10 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine
std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
std::vector<py::handle> A, bool transa, std::vector<py::handle> B, bool transb,
std::optional<std::vector<at::Tensor>> D, transformer_engine::DType D_type,
std::vector<int64_t> m_splits, std::vector<at::Tensor> bias,
transformer_engine::DType bias_type, bool single_output, std::vector<at::Tensor> pre_gelu_out,
bool grad, std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator, int math_sm_count) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
std::optional<std::vector<at::Tensor>> D, DType D_type, std::vector<int64_t> m_splits,
std::vector<at::Tensor> bias, DType bias_type, bool single_output,
std::vector<at::Tensor> pre_gelu_out, bool grad, std::vector<at::Tensor> workspace,
size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count) {
std::vector<NVTETensor> te_A_vector, te_B_vector, te_D_vector, te_bias_vector,
te_pre_gelu_out_vector, te_workspace_vector;
std::vector<TensorWrapper> wrappers;
......@@ -450,3 +441,5 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
});
return bias;
}
} // namespace transformer_engine::pytorch
......@@ -6,6 +6,10 @@
#include "extensions.h"
namespace transformer_engine::pytorch {
size_t get_cublasLt_version() { return cublasLtGetVersion(); }
size_t get_cudnn_version() { return cudnnGetVersion(); }
} // namespace transformer_engine::pytorch
......@@ -6,14 +6,13 @@
#include "extensions.h"
namespace transformer_engine::pytorch {
void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, const float lr,
const float beta1, const float beta2, const float epsilon,
const int step, const int mode, const int bias_correction,
const float weight_decay) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists);
......@@ -29,9 +28,6 @@ void multi_tensor_adam_param_remainder_cuda(int chunk_size, at::Tensor noop_flag
const float lr, const float beta1, const float beta2,
const float epsilon, const int step, const int mode,
const int bias_correction, const float weight_decay) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists);
......@@ -48,9 +44,6 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag,
const float beta1, const float beta2, const float epsilon,
const int step, const int mode, const int bias_correction,
const float weight_decay, DType fp8_dtype) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists);
......@@ -68,9 +61,6 @@ void multi_tensor_adam_capturable_cuda(int chunk_size, at::Tensor noop_flag,
const float epsilon, at::Tensor step, const int mode,
const int bias_correction, const float weight_decay,
at::Tensor inv_scale) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists);
......@@ -91,9 +81,6 @@ void multi_tensor_adam_capturable_master_cuda(int chunk_size, at::Tensor noop_fl
const float epsilon, at::Tensor step, const int mode,
const int bias_correction, const float weight_decay,
at::Tensor inv_scale) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists);
......@@ -107,3 +94,5 @@ void multi_tensor_adam_capturable_master_cuda(int chunk_size, at::Tensor noop_fl
lr_cu.data(), beta1, beta2, epsilon, step_cu.data(), mode, bias_correction, weight_decay,
inv_scale_cu.data(), device_id, at::cuda::getCurrentCUDAStream());
}
} // namespace transformer_engine::pytorch
......@@ -6,12 +6,11 @@
#include "extensions.h"
namespace transformer_engine::pytorch {
void multi_tensor_compute_scale_and_scale_inv_cuda(
int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
float max_fp8, bool force_pow_2_scales, float epsilon) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists);
......@@ -21,3 +20,5 @@ void multi_tensor_compute_scale_and_scale_inv_cuda(
chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors, max_fp8,
force_pow_2_scales, epsilon, device_id, at::cuda::getCurrentCUDAStream());
}
} // namespace transformer_engine::pytorch
......@@ -6,12 +6,11 @@
#include "extensions.h"
namespace transformer_engine::pytorch {
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
at::optional<bool> per_tensor_python) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
bool per_tensor = per_tensor_python.has_value() ? per_tensor_python.value() : false;
auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
......@@ -57,9 +56,6 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
std::tuple<at::Tensor, at::Tensor> multi_tensor_unscale_l2norm_cuda(
int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor inv_scale, at::optional<bool> per_tensor_python) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
bool per_tensor = per_tensor_python.has_value() ? per_tensor_python.value() : false;
auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
......@@ -105,3 +101,5 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_unscale_l2norm_cuda(
return std::tuple<at::Tensor, at::Tensor>(ret, ret_per_tensor);
}
} // namespace transformer_engine::pytorch
......@@ -6,11 +6,10 @@
#include "extensions.h"
namespace transformer_engine::pytorch {
void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, float scale) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists);
......@@ -19,3 +18,5 @@ void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag,
nvte_multi_tensor_scale_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists,
num_tensors, scale, device_id, at::cuda::getCurrentCUDAStream());
}
} // namespace transformer_engine::pytorch
......@@ -6,13 +6,12 @@
#include "extensions.h"
namespace transformer_engine::pytorch {
void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, float wd,
float momentum, float dampening, float lr, bool nesterov, bool first_run,
bool wd_after_momentum, float scale) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists);
......@@ -22,3 +21,5 @@ void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
num_tensors, wd, momentum, dampening, lr, nesterov, first_run,
wd_after_momentum, scale, device_id, at::cuda::getCurrentCUDAStream());
}
} // namespace transformer_engine::pytorch
......@@ -9,28 +9,11 @@
#include "pybind.h"
namespace transformer_engine::pytorch {
std::pair<TensorWrapper, py::object> createOutputTensor(const NVTEShape &shape, DType dtype,
py::handle quantizer) {
std::vector<size_t> shape_vec;
for (size_t i = 0; i < shape.ndim; i++) {
size_t t = shape.data[i];
shape_vec.push_back(t);
}
std::unique_ptr<Quantizer> my_quantizer = convert_quantizer(quantizer);
return my_quantizer->create_tensor(shape_vec, dtype);
}
std::pair<TensorWrapper, py::object> createOutputTensor(std::vector<size_t> &shape, DType dtype,
py::handle quantizer) {
std::unique_ptr<Quantizer> my_quantizer = convert_quantizer(quantizer);
return my_quantizer->create_tensor(shape, dtype);
}
} // namespace transformer_engine::pytorch
std::vector<py::object> layernorm_bwd(const at::Tensor &dz, const at::Tensor &x,
const at::Tensor &mu, const at::Tensor &rsigma,
const at::Tensor &gamma, const int sm_margin,
const bool zero_centered_gamma) {
using namespace transformer_engine::pytorch;
const auto &dz_ = dz.contiguous();
const auto &x_ = x.contiguous();
const auto &mu_ = mu.contiguous();
......@@ -40,7 +23,7 @@ std::vector<py::object> layernorm_bwd(const at::Tensor &dz, const at::Tensor &x,
auto dx = at::empty_like(x_);
auto dgamma = at::empty_like(gamma_);
auto dbeta = at::empty_like(gamma_);
transformer_engine::TensorWrapper workspace;
TensorWrapper workspace;
auto dz_cu = makeTransformerEngineTensor(dz_);
auto x_cu = makeTransformerEngineTensor(x_);
......@@ -80,8 +63,6 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
DType out_dtype, const int sm_margin,
const bool zero_centered_gamma) {
using namespace transformer_engine::pytorch::detail;
using namespace transformer_engine::pytorch;
using namespace transformer_engine;
// Input and param tensors
auto none = py::none();
......@@ -135,7 +116,7 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
TensorWrapper &kernel_out_cu = force_unfused_kernel ? unquantized_out_cu : out_cu;
// Query workspace size
transformer_engine::TensorWrapper workspace;
TensorWrapper workspace;
NVTE_SCOPED_GIL_RELEASE({
nvte_layernorm_fwd(input_cu.data(), weight_cu.data(), bias_cu.data(), eps, kernel_out_cu.data(),
mu_cu.data(), rsigma_cu.data(), workspace.data(),
......@@ -202,7 +183,6 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
std::vector<py::object> rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x,
const at::Tensor &rsigma, const at::Tensor &gamma,
const int sm_margin, const bool zero_centered_gamma) {
using namespace transformer_engine::pytorch;
const auto &dz_ = dz.contiguous();
const auto &x_ = x.contiguous();
const auto &rsigma_ = rsigma.contiguous();
......@@ -210,7 +190,7 @@ std::vector<py::object> rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x,
auto dx = at::empty_like(x_);
auto dgamma = at::empty_like(gamma_);
transformer_engine::TensorWrapper workspace;
TensorWrapper workspace;
auto dz_cu = makeTransformerEngineTensor(dz_);
auto x_cu = makeTransformerEngineTensor(x_);
......@@ -244,12 +224,9 @@ std::vector<py::object> rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x,
}
std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &weight, float eps,
py::object out, py::handle quantizer,
transformer_engine::DType out_dtype, const int sm_margin,
const bool zero_centered_gamma) {
py::object out, py::handle quantizer, DType out_dtype,
const int sm_margin, const bool zero_centered_gamma) {
using namespace transformer_engine::pytorch::detail;
using namespace transformer_engine::pytorch;
using namespace transformer_engine;
// Input and param tensors
auto none = py::none();
......@@ -297,7 +274,7 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
TensorWrapper &kernel_out_cu = force_unfused_kernel ? unquantized_out_cu : out_cu;
// Query workspace size
transformer_engine::TensorWrapper workspace;
TensorWrapper workspace;
NVTE_SCOPED_GIL_RELEASE({
nvte_rmsnorm_fwd(input_cu.data(), weight_cu.data(), eps, kernel_out_cu.data(), rsigma_cu.data(),
workspace.data(),
......@@ -360,3 +337,5 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
return {out, py::none(), py::cast(rsigma)};
}
} // namespace transformer_engine::pytorch
......@@ -17,7 +17,8 @@
#include <torch/cuda.h>
#include <torch/extension.h>
namespace nvshmem_api {
namespace transformer_engine::pytorch {
void init_nvshmem_backend(c10d::ProcessGroup *process_group) {
#ifdef NVTE_ENABLE_NVSHMEM
nvshmemx_init_attr_t attr = {};
......@@ -126,4 +127,5 @@ void nvshmem_finalize() {
"distributed process groups when TE is compiled with NVTE_ENABLE_NVSHMEM=1!");
#endif
}
} // namespace nvshmem_api
} // namespace transformer_engine::pytorch
......@@ -7,12 +7,11 @@
#include "extensions.h"
#include "pybind.h"
namespace transformer_engine::pytorch {
void fused_multi_row_padding(at::Tensor input, at::Tensor output,
std::vector<size_t> input_row_list,
std::vector<size_t> padded_input_row_list) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
NVTE_CHECK(input_row_list.size() == padded_input_row_list.size(),
"Number of input row list and padded row list must match.");
NVTE_CHECK(input.dim() == 2, "Dimension of input must equal 2.");
......@@ -22,7 +21,7 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output,
// Extract properties from PyTorch tensors
std::vector<void*> input_dptr_list, output_dptr_list;
std::vector<std::vector<size_t>> input_shape_list, output_shape_list;
std::vector<transformer_engine::DType> input_type_list;
std::vector<DType> input_type_list;
void* d_input_ptr = reinterpret_cast<void*>(input.data_ptr());
void* d_output_ptr = reinterpret_cast<void*>(output.data_ptr());
for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) {
......@@ -52,9 +51,9 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output,
// Construct TE tensors
std::vector<NVTETensor> nvte_input_list, nvte_output_list;
std::vector<transformer_engine::TensorWrapper> tensor_wrappers;
std::vector<TensorWrapper> tensor_wrappers;
auto make_tensor = [&tensor_wrappers](void* dptr, const std::vector<size_t>& shape,
transformer_engine::DType dtype) -> NVTETensor {
DType dtype) -> NVTETensor {
tensor_wrappers.emplace_back(makeTransformerEngineTensor(dptr, shape, dtype));
return tensor_wrappers.back().data();
};
......@@ -81,3 +80,5 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output,
padded_num_rows_list.data(), at::cuda::getCurrentCUDAStream());
});
}
} // namespace transformer_engine::pytorch
......@@ -6,10 +6,11 @@
#include "extensions.h"
namespace transformer_engine::pytorch {
std::tuple<at::Tensor, at::Tensor, std::vector<at::Tensor>> moe_permute_fwd(
at::Tensor input, const transformer_engine::DType dtype, at::Tensor indices,
int64_t num_out_tokens, std::vector<at::Tensor> workspace, int64_t max_expanded_token_num) {
using namespace transformer_engine::pytorch;
at::Tensor input, const DType dtype, at::Tensor indices, int64_t num_out_tokens,
std::vector<at::Tensor> workspace, int64_t max_expanded_token_num) {
const int num_tokens = input.size(0);
int num_cols = input.size(1);
const int topK = indices.size(1);
......@@ -71,28 +72,23 @@ std::tuple<at::Tensor, at::Tensor, std::vector<at::Tensor>> moe_permute_fwd(
dtype);
auto sorted_row_id_cu = makeTransformerEngineTensor(
sorted_row_id_ptr, std::vector<size_t>{static_cast<size_t>(num_tokens * topK)},
transformer_engine::DType::kInt32);
DType::kInt32);
auto row_id_map_cu = makeTransformerEngineTensor(row_id_map);
nvte_permute(input_cu.data(), permuted_output_cu.data(), sorted_row_id_cu.data(),
row_id_map_cu.data(), transformer_engine::TensorWrapper().data(),
transformer_engine::TensorWrapper().data(),
transformer_engine::TensorWrapper().data(), num_tokens, topK, num_cols,
num_out_tokens, stream);
row_id_map_cu.data(), TensorWrapper().data(), TensorWrapper().data(),
TensorWrapper().data(), num_tokens, topK, num_cols, num_out_tokens, stream);
return std::make_tuple(permuted_output, row_id_map, workspace);
}
at::Tensor moe_permute_bwd(at::Tensor input, const transformer_engine::DType dtype,
at::Tensor row_id_map, at::Tensor prob, int64_t num_tokens,
int64_t topK) {
at::Tensor moe_permute_bwd(at::Tensor input, const DType dtype, at::Tensor row_id_map,
at::Tensor prob, int64_t num_tokens, int64_t topK) {
return moe_unpermute_fwd(input, dtype, row_id_map, prob, num_tokens, topK);
}
at::Tensor moe_unpermute_fwd(at::Tensor input, const transformer_engine::DType dtype,
at::Tensor row_id_map, at::Tensor prob, int64_t num_tokens,
int64_t topK) {
using namespace transformer_engine::pytorch;
at::Tensor moe_unpermute_fwd(at::Tensor input, const DType dtype, at::Tensor row_id_map,
at::Tensor prob, int64_t num_tokens, int64_t topK) {
int num_cols = input.size(1);
// Output buffer alloc
......@@ -121,9 +117,8 @@ at::Tensor moe_unpermute_fwd(at::Tensor input, const transformer_engine::DType d
}
std::tuple<at::Tensor, at::Tensor> moe_unpermute_bwd(at::Tensor input_bwd, at::Tensor input_fwd,
const transformer_engine::DType dtype,
at::Tensor row_id_map, at::Tensor prob) {
using namespace transformer_engine::pytorch;
const DType dtype, at::Tensor row_id_map,
at::Tensor prob) {
const int topK = (prob.numel() > 0) ? prob.size(1) : 1;
const int num_tokens = (prob.numel() > 0) ? prob.size(0) : row_id_map.size(0);
int num_cols = input_bwd.size(1);
......@@ -153,9 +148,11 @@ std::tuple<at::Tensor, at::Tensor> moe_unpermute_bwd(at::Tensor input_bwd, at::T
auto prob_cu = makeTransformerEngineTensor(prob);
auto prob_grad_cu = makeTransformerEngineTensor(prob_grad);
nvte_permute(input_bwd_cu.data(), act_grad_cu.data(), transformer_engine::TensorWrapper().data(),
nvte_permute(input_bwd_cu.data(), act_grad_cu.data(), TensorWrapper().data(),
row_id_map_cu.data(), prob_cu.data(), prob_grad_cu.data(), input_fwd_cu.data(),
num_tokens, topK, num_cols, 0, stream);
return std::make_tuple(act_grad, prob_grad);
}
} // namespace transformer_engine::pytorch
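With every op now declared inside transformer_engine::pytorch, the pybind hunk below re-points each m.def at the fully qualified symbol instead of relying on a file-scope using-directive. A minimal sketch of that binding pattern (example_op is a hypothetical stand-in; TORCH_EXTENSION_NAME is defined by the extension build, as in the real module):

#include <torch/extension.h>

namespace transformer_engine::pytorch {
int example_op(int x) { return x + 1; }  // hypothetical stand-in for a real op
}  // namespace transformer_engine::pytorch

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // Bind by fully qualified name; release the GIL for the call's duration.
  m.def("example_op", &transformer_engine::pytorch::example_op, "Example op",
        py::arg("x"), py::call_guard<py::gil_scoped_release>());
}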
......@@ -110,10 +110,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("workspace_size"), py::arg("accumulate"), py::arg("use_split_accumulator"),
py::arg("comm_overlap") = nullptr, py::arg("comm_type") = std::nullopt,
py::arg("extra_output") = std::nullopt, py::arg("bulk_overlap") = false);
m.def("rowwise_swizzle", &rowwise_swizzle, "Swizzle rowwise scale inverses.",
py::call_guard<py::gil_scoped_release>());
m.def("columnwise_swizzle", &columnwise_swizzle, "Swizzle columnwise scale inverses.",
py::call_guard<py::gil_scoped_release>());
m.def("gelu", transformer_engine::pytorch::gelu, "GeLU activation", py::arg("input"),
py::arg("quantizer"));
m.def("relu", transformer_engine::pytorch::relu, "ReLU activation", py::arg("input"),
......@@ -159,170 +155,188 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("quantizer"));
// Permutation functions
m.def("moe_permute_fwd", moe_permute_fwd, "MOE permute FWD",
m.def("moe_permute_fwd", transformer_engine::pytorch::moe_permute_fwd, "MOE permute FWD",
py::call_guard<py::gil_scoped_release>());
m.def("moe_permute_bwd", moe_permute_bwd, "MOE permute BWD",
m.def("moe_permute_bwd", transformer_engine::pytorch::moe_permute_bwd, "MOE permute BWD",
py::call_guard<py::gil_scoped_release>());
m.def("moe_unpermute_fwd", moe_unpermute_fwd, "MOE unpermute FWD",
m.def("moe_unpermute_fwd", transformer_engine::pytorch::moe_unpermute_fwd, "MOE unpermute FWD",
py::call_guard<py::gil_scoped_release>());
m.def("moe_unpermute_bwd", moe_unpermute_bwd, "MOE unpermute BWD",
m.def("moe_unpermute_bwd", transformer_engine::pytorch::moe_unpermute_bwd, "MOE unpermute BWD",
py::call_guard<py::gil_scoped_release>());
// Softmax functions
m.def("scaled_softmax_forward", &scaled_softmax_forward, "Scaled Softmax FWD",
py::call_guard<py::gil_scoped_release>());
m.def("scaled_softmax_backward", &scaled_softmax_backward, "Scaled Softmax BWD",
py::call_guard<py::gil_scoped_release>());
m.def("scaled_masked_softmax_forward", &scaled_masked_softmax_forward,
"Scaled Masked Softmax FWD", py::call_guard<py::gil_scoped_release>());
m.def("scaled_masked_softmax_backward", &scaled_masked_softmax_backward,
"Scaled Masked Softmax BWD", py::call_guard<py::gil_scoped_release>());
m.def("scaled_upper_triang_masked_softmax_forward", &scaled_upper_triang_masked_softmax_forward,
m.def("scaled_softmax_forward", &transformer_engine::pytorch::scaled_softmax_forward,
"Scaled Softmax FWD", py::call_guard<py::gil_scoped_release>());
m.def("scaled_softmax_backward", &transformer_engine::pytorch::scaled_softmax_backward,
"Scaled Softmax BWD", py::call_guard<py::gil_scoped_release>());
m.def("scaled_masked_softmax_forward",
&transformer_engine::pytorch::scaled_masked_softmax_forward, "Scaled Masked Softmax FWD",
py::call_guard<py::gil_scoped_release>());
m.def("scaled_masked_softmax_backward",
&transformer_engine::pytorch::scaled_masked_softmax_backward, "Scaled Masked Softmax BWD",
py::call_guard<py::gil_scoped_release>());
m.def("scaled_upper_triang_masked_softmax_forward",
&transformer_engine::pytorch::scaled_upper_triang_masked_softmax_forward,
"Scaled Upper-Triangular Masked Softmax FWD", py::call_guard<py::gil_scoped_release>());
m.def("scaled_upper_triang_masked_softmax_backward", &scaled_upper_triang_masked_softmax_backward,
m.def("scaled_upper_triang_masked_softmax_backward",
&transformer_engine::pytorch::scaled_upper_triang_masked_softmax_backward,
"Scaled Upper-Triangular Masked Softmax BWD", py::call_guard<py::gil_scoped_release>());
m.def("scaled_aligned_causal_masked_softmax_forward",
&scaled_aligned_causal_masked_softmax_forward,
&transformer_engine::pytorch::scaled_aligned_causal_masked_softmax_forward,
"Scaled Bottom-Right Corner Aligned Masked Softmax FWD",
py::call_guard<py::gil_scoped_release>());
m.def("scaled_aligned_causal_masked_softmax_backward",
&scaled_aligned_causal_masked_softmax_backward,
&transformer_engine::pytorch::scaled_aligned_causal_masked_softmax_backward,
"Scaled Bottom-Right Corner Aligned Masked Softmax BWD",
py::call_guard<py::gil_scoped_release>());
// Other granular functions
m.def("layernorm_fwd", &layernorm_fwd, "LayerNorm", py::arg("input"), py::arg("weight"),
py::arg("bias"), py::arg("eps"), py::arg("ln_out"), py::arg("quantizer"), py::arg("otype"),
py::arg("sm_margin"), py::arg("zero_centered_gamma"));
m.def("layernorm_bwd", &layernorm_bwd, "Backward of LayerNorm");
m.def("rmsnorm_fwd", &rmsnorm_fwd, "RMSNorm", py::arg("input"), py::arg("weight"), py::arg("eps"),
py::arg("ln_out"), py::arg("quantizer"), py::arg("otype"), py::arg("sm_margin"),
py::arg("zero_centered_gamma"));
m.def("rmsnorm_bwd", &rmsnorm_bwd, "Backward of RMSNorm");
m.def("layernorm_fwd", &transformer_engine::pytorch::layernorm_fwd, "LayerNorm", py::arg("input"),
py::arg("weight"), py::arg("bias"), py::arg("eps"), py::arg("ln_out"), py::arg("quantizer"),
py::arg("otype"), py::arg("sm_margin"), py::arg("zero_centered_gamma"));
m.def("layernorm_bwd", &transformer_engine::pytorch::layernorm_bwd, "Backward of LayerNorm");
m.def("rmsnorm_fwd", &transformer_engine::pytorch::rmsnorm_fwd, "RMSNorm", py::arg("input"),
py::arg("weight"), py::arg("eps"), py::arg("ln_out"), py::arg("quantizer"),
py::arg("otype"), py::arg("sm_margin"), py::arg("zero_centered_gamma"));
m.def("rmsnorm_bwd", &transformer_engine::pytorch::rmsnorm_bwd, "Backward of RMSNorm");
m.def("fused_multi_quantize", &transformer_engine::pytorch::fused_multi_quantize,
"Fused Multi-tensor Cast + Transpose", py::arg("input_list"), py::arg("output_list"),
py::arg("quantizer_list"), py::arg("otype"));
m.def("te_general_grouped_gemm", &te_general_grouped_gemm, "Grouped GEMM");
m.def("te_general_grouped_gemm", &transformer_engine::pytorch::te_general_grouped_gemm,
"Grouped GEMM");
m.def("fp8_transpose", &transformer_engine::pytorch::fp8_transpose, "Transpose with FP8 I/O",
py::arg("input"), py::arg("dtype"), py::kw_only(), py::arg("out"),
py::call_guard<py::gil_scoped_release>());
m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend",
m.def("get_fused_attn_backend", &transformer_engine::pytorch::get_fused_attn_backend,
"Get Fused Attention backend", py::call_guard<py::gil_scoped_release>());
m.def("compute_amax", &transformer_engine::pytorch::compute_amax,
"Compute absolute max value in tensor", py::arg("input"), py::arg("amax"),
py::call_guard<py::gil_scoped_release>());
m.def("compute_amax", &compute_amax, "Compute absolute max value in tensor", py::arg("input"),
py::arg("amax"), py::call_guard<py::gil_scoped_release>());
m.def("fused_amax_and_scale_update_after_reduction", &fused_amax_and_scale_update_after_reduction,
m.def("fused_amax_and_scale_update_after_reduction",
&transformer_engine::pytorch::fused_amax_and_scale_update_after_reduction,
"Update amax history and FP8 scale/scale_inv after reduction",
py::call_guard<py::gil_scoped_release>());
m.def("fp8_block_scaling_compute_partial_amax", &fp8_block_scaling_compute_partial_amax,
m.def("fp8_block_scaling_compute_partial_amax",
&transformer_engine::pytorch::fp8_block_scaling_compute_partial_amax,
"Compute partial amax from master weights for fp8 block scaling", py::arg("tensor"),
py::arg("amax"), py::arg("h"), py::arg("w"), py::arg("start_offset"), py::arg("block_len"),
py::call_guard<py::gil_scoped_release>());
m.def("fp8_block_scaling_partial_cast", &fp8_block_scaling_partial_cast,
m.def("fp8_block_scaling_partial_cast",
&transformer_engine::pytorch::fp8_block_scaling_partial_cast,
"Partial cast from master weights for fp8 block scaling", py::arg("inp"), py::arg("out"),
py::arg("scale"), py::arg("h"), py::arg("w"), py::arg("start_offset"), py::arg("block_len"),
py::arg("out_dtype"), py::call_guard<py::gil_scoped_release>());
m.def("fused_multi_row_padding", &fused_multi_row_padding, "Fused Multi-tensor padding",
py::call_guard<py::gil_scoped_release>());
m.def("fused_multi_row_padding", &transformer_engine::pytorch::fused_multi_row_padding,
"Fused Multi-tensor padding", py::call_guard<py::gil_scoped_release>());
// attention kernels
m.def("fa_prepare_fwd", &fa_prepare_fwd, "Prepare QKV for Flash Attention",
py::call_guard<py::gil_scoped_release>());
m.def("fa_prepare_bwd", &fa_prepare_bwd, "Backward of QKV preparation for Flash Attention",
m.def("fa_prepare_fwd", &transformer_engine::pytorch::fa_prepare_fwd,
"Prepare QKV for Flash Attention", py::call_guard<py::gil_scoped_release>());
m.def("fa_prepare_bwd", &transformer_engine::pytorch::fa_prepare_bwd,
"Backward of QKV preparation for Flash Attention",
py::call_guard<py::gil_scoped_release>());
m.def("fused_attn_fwd", &fused_attn_fwd,
m.def("fused_attn_fwd", &transformer_engine::pytorch::fused_attn_fwd,
"Fused Attention FP8/BF16/FP16 FWD with separate Q, K and V");
m.def("fused_attn_bwd", &fused_attn_bwd,
m.def("fused_attn_bwd", &transformer_engine::pytorch::fused_attn_bwd,
"Fused Attention FP8/BF16/FP16 BWD with separate Q, K and V");
m.def("copy_to_kv_cache", &copy_to_kv_cache, "Copy new KV tokens to KV cache",
py::call_guard<py::gil_scoped_release>());
m.def("convert_thd_to_bshd", &convert_thd_to_bshd, "Convert a tensor from THD to BSHD",
py::call_guard<py::gil_scoped_release>());
m.def("convert_bshd_to_thd", &convert_bshd_to_thd, "Convert a tesnor from BSHD to THD",
py::call_guard<py::gil_scoped_release>());
m.def("copy_to_kv_cache", &transformer_engine::pytorch::copy_to_kv_cache,
"Copy new KV tokens to KV cache", py::call_guard<py::gil_scoped_release>());
m.def("convert_thd_to_bshd", &transformer_engine::pytorch::convert_thd_to_bshd,
"Convert a tensor from THD to BSHD", py::call_guard<py::gil_scoped_release>());
m.def("convert_bshd_to_thd", &transformer_engine::pytorch::convert_bshd_to_thd,
"Convert a tesnor from BSHD to THD", py::call_guard<py::gil_scoped_release>());
// fused apply rope
m.def("fused_rope_forward", &fused_rope_forward, "Fused Apply RoPE FWD",
py::call_guard<py::gil_scoped_release>());
m.def("fused_rope_backward", &fused_rope_backward, "Fused Apply RoPE BWD",
py::call_guard<py::gil_scoped_release>());
m.def("fused_rope_forward", &transformer_engine::pytorch::fused_rope_forward,
"Fused Apply RoPE FWD", py::call_guard<py::gil_scoped_release>());
m.def("fused_rope_backward", &transformer_engine::pytorch::fused_rope_backward,
"Fused Apply RoPE BWD", py::call_guard<py::gil_scoped_release>());
// Misc
m.def("get_cublasLt_version", &get_cublasLt_version, "Get cublasLt version",
py::call_guard<py::gil_scoped_release>());
m.def("get_cudnn_version", &get_cudnn_version, "Get cuDNN version",
m.def("get_cublasLt_version", &transformer_engine::pytorch::get_cublasLt_version,
"Get cublasLt version", py::call_guard<py::gil_scoped_release>());
m.def("get_cudnn_version", &transformer_engine::pytorch::get_cudnn_version, "Get cuDNN version",
py::call_guard<py::gil_scoped_release>());
m.attr("_num_cublas_streams") = py::int_(transformer_engine::num_streams);
// Support THD format for Context Parallel
m.def("thd_read_half_tensor", &thd_read_half_tensor,
m.def("thd_read_half_tensor", &transformer_engine::pytorch::thd_read_half_tensor,
"Read the first half(half_idx=0) or the second half(half_idx=1) of each sequence in a THD "
"tensor",
py::call_guard<py::gil_scoped_release>());
m.def("thd_second_half_lse_correction", &thd_second_half_lse_correction,
m.def("thd_second_half_lse_correction",
&transformer_engine::pytorch::thd_second_half_lse_correction,
"Correct the second half of the softmax_lse", py::call_guard<py::gil_scoped_release>());
m.def("thd_read_second_half_lse", &thd_read_second_half_lse,
m.def("thd_read_second_half_lse", &transformer_engine::pytorch::thd_read_second_half_lse,
"Read the second half of the softmax_lse", py::call_guard<py::gil_scoped_release>());
m.def("thd_out_correction", &thd_out_correction,
m.def("thd_out_correction", &transformer_engine::pytorch::thd_out_correction,
"Correct the THD format output of context parallelism in forward pass",
py::call_guard<py::gil_scoped_release>());
m.def("thd_grad_correction", &thd_grad_correction,
m.def("thd_grad_correction", &transformer_engine::pytorch::thd_grad_correction,
"Correct the THD format gradients of context parallelism in backward pass",
py::call_guard<py::gil_scoped_release>());
m.def("thd_get_partitioned_indices", &thd_get_partitioned_indices,
m.def("thd_get_partitioned_indices", &transformer_engine::pytorch::thd_get_partitioned_indices,
"Generate partitioned indices for inputs in THD format",
py::call_guard<py::gil_scoped_release>());
// nvshmem functions
m.def("init_nvshmem_backend", &nvshmem_api::init_nvshmem_backend,
m.def("init_nvshmem_backend", &transformer_engine::pytorch::init_nvshmem_backend,
"Initialize nvshmem backend with Pytorch distributed process groups",
py::call_guard<py::gil_scoped_release>());
m.def("create_nvshmem_tensor", &nvshmem_api::create_nvshmem_tensor,
m.def("create_nvshmem_tensor", &transformer_engine::pytorch::create_nvshmem_tensor,
"Create a tensor in NVSHMEM shared memory", py::call_guard<py::gil_scoped_release>());
m.def("nvshmem_send_on_current_stream", &nvshmem_api::nvshmem_send_on_current_stream,
m.def("nvshmem_send_on_current_stream",
&transformer_engine::pytorch::nvshmem_send_on_current_stream,
"Asynchronously send tensor data to a remote PE using NVSHMEM on the current CUDA stream",
py::call_guard<py::gil_scoped_release>());
m.def("nvshmem_wait_on_current_stream", &nvshmem_api::nvshmem_wait_on_current_stream,
m.def("nvshmem_wait_on_current_stream",
&transformer_engine::pytorch::nvshmem_wait_on_current_stream,
"Wait for a signal value to be updated by a remote PE using NVSHMEM on the current CUDA "
"stream",
py::call_guard<py::gil_scoped_release>());
m.def("nvshmem_finalize", &nvshmem_api::nvshmem_finalize,
m.def("nvshmem_finalize", &transformer_engine::pytorch::nvshmem_finalize,
"Clean up and finalize the NVSHMEM communication backend and free associated resources",
py::call_guard<py::gil_scoped_release>());
// multi-tensor functions
m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
m.def("multi_tensor_scale", &transformer_engine::pytorch::multi_tensor_scale_cuda,
"Fused overflow check + scale for a list of contiguous tensors",
py::call_guard<py::gil_scoped_release>());
m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda,
m.def("multi_tensor_l2norm", &transformer_engine::pytorch::multi_tensor_l2norm_cuda,
"Computes L2 norm for a list of contiguous tensors",
py::call_guard<py::gil_scoped_release>());
m.def("multi_tensor_unscale_l2norm", &multi_tensor_unscale_l2norm_cuda,
m.def("multi_tensor_unscale_l2norm",
&transformer_engine::pytorch::multi_tensor_unscale_l2norm_cuda,
"Computes L2 norm for a list of contiguous tensors after unscaling (unscaling is only "
"performed for L2 norm computation, and tensors are not updated)",
py::call_guard<py::gil_scoped_release>());
m.def("multi_tensor_adam", &multi_tensor_adam_cuda,
m.def("multi_tensor_adam", &transformer_engine::pytorch::multi_tensor_adam_cuda,
"Compute and apply gradient update to parameters for Adam optimizer",
py::call_guard<py::gil_scoped_release>());
m.def("multi_tensor_adam_param_remainder", &multi_tensor_adam_param_remainder_cuda,
m.def("multi_tensor_adam_param_remainder",
&transformer_engine::pytorch::multi_tensor_adam_param_remainder_cuda,
"Compute and apply gradient update to parameters for Adam optimizer"
"where the master parameters only store the remainder bits",
py::call_guard<py::gil_scoped_release>());
m.def("multi_tensor_adam_fp8", &multi_tensor_adam_fp8_cuda,
m.def("multi_tensor_adam_fp8", &transformer_engine::pytorch::multi_tensor_adam_fp8_cuda,
"Compute and apply gradient update to parameters for Adam optimizer",
py::call_guard<py::gil_scoped_release>());
m.def("multi_tensor_adam_capturable", &multi_tensor_adam_capturable_cuda,
m.def("multi_tensor_adam_capturable",
&transformer_engine::pytorch::multi_tensor_adam_capturable_cuda,
"Compute and apply gradient update to parameters for Adam optimizer with CUDA graph "
"support and LR scheduling",
py::call_guard<py::gil_scoped_release>());
m.def("multi_tensor_adam_capturable_master", &multi_tensor_adam_capturable_master_cuda,
m.def("multi_tensor_adam_capturable_master",
&transformer_engine::pytorch::multi_tensor_adam_capturable_master_cuda,
"Compute and apply gradient update to parameters for Adam optimizer with CUDA graph "
"support, LR scheduling and FP32 master weights",
py::call_guard<py::gil_scoped_release>());
m.def("multi_tensor_sgd", &multi_tensor_sgd_cuda,
m.def("multi_tensor_sgd", &transformer_engine::pytorch::multi_tensor_sgd_cuda,
"Fused SGD optimizer for list of contiguous tensors",
py::call_guard<py::gil_scoped_release>());
m.def("multi_tensor_compute_scale_and_scale_inv", &multi_tensor_compute_scale_and_scale_inv_cuda,
m.def("multi_tensor_compute_scale_and_scale_inv",
&transformer_engine::pytorch::multi_tensor_compute_scale_and_scale_inv_cuda,
"Fused compute scale and scale_inv from amax", py::call_guard<py::gil_scoped_release>());
// Data structures
......
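Every registration in the hunk above follows the same pybind11 recipe: a fully qualified function pointer, a docstring, optional py::arg names, and py::call_guard<py::gil_scoped_release> so the Python GIL is dropped while the C++/CUDA routine runs. A minimal, self-contained sketch of that recipe follows; the module and function names here are hypothetical stand-ins, not part of the extension:

    #include <pybind11/pybind11.h>

    namespace py = pybind11;

    namespace demo {
    // Stand-in for a long-running C++/CUDA routine that never touches
    // Python objects, so it is safe to execute without holding the GIL.
    int heavy_compute(int x) { return x * x; }
    }  // namespace demo

    PYBIND11_MODULE(demo_ext, m) {
      // Qualified function pointer + docstring + named argument + a call
      // guard that releases the GIL for the duration of the call -- the
      // same shape as the m.def(...) entries above.
      m.def("heavy_compute", &demo::heavy_compute, "Square an integer",
            py::arg("x"), py::call_guard<py::gil_scoped_release>());
    }

Releasing the GIL matters for these bindings because the bound functions launch GPU work and may block; without the guard, other Python threads would stall behind each extension call.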
......@@ -12,10 +12,9 @@
#include "common/common.h"
#include "extensions.h"
void compute_amax(const at::Tensor& tensor, at::Tensor& amax) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
namespace transformer_engine::pytorch {
void compute_amax(const at::Tensor& tensor, at::Tensor& amax) {
auto input_tensor = tensor.contiguous();
const TensorWrapper& te_input = makeTransformerEngineTensor(input_tensor);
......@@ -23,7 +22,7 @@ void compute_amax(const at::Tensor& tensor, at::Tensor& amax) {
TORCH_CHECK(amax.numel() == 1, "amax must have exactly one element");
TensorWrapper fake_te_output(
nullptr, te_input.shape(),
transformer_engine::DType::kFloat8E4M3, // It doesn't matter because we only compute amax.
DType::kFloat8E4M3, // It doesn't matter because we only compute amax.
amax.data_ptr<float>());
nvte_compute_amax(te_input.data(), fake_te_output.data(), at::cuda::getCurrentCUDAStream());
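One detail in this hunk is easy to miss: nvte_compute_amax only writes the amax side-channel of its output tensor, so the output wrapper is built with a null data pointer and an arbitrary FP8 dtype. Condensed, with comments added for clarity (the calls are exactly those in the hunk; the surrounding function body is schematic):

    // The output wrapper exists only to carry the amax pointer; the data
    // pointer is null and the FP8 dtype is a placeholder, since the kernel
    // reads neither.
    TensorWrapper fake_te_output(nullptr, te_input.shape(),
                                 DType::kFloat8E4M3,  // irrelevant here
                                 amax.data_ptr<float>());
    // Reduces |input| into fake_te_output's amax field on the current stream.
    nvte_compute_amax(te_input.data(), fake_te_output.data(),
                      at::cuda::getCurrentCUDAStream());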
......@@ -33,10 +32,7 @@ void fused_amax_and_scale_update_after_reduction(const at::Tensor& amax_reductio
std::vector<at::Tensor> amax_histories,
std::vector<at::Tensor> scales,
const std::string& amax_compute_algo,
transformer_engine::DType fp8_dtype,
float margin) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
DType fp8_dtype, float margin) {
size_t num_tensors = amax_histories.size();
std::vector<Tensor> t_amax_histories(num_tensors);
std::vector<Tensor> t_scales(num_tensors);
......@@ -63,3 +59,5 @@ void fused_amax_and_scale_update_after_reduction(const at::Tensor& amax_reductio
amax_compute_algo.c_str(), static_cast<NVTEDType>(fp8_dtype), margin,
at::cuda::getCurrentCUDAStream());
}
} // namespace transformer_engine::pytorch
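The second file shows the refactor pattern applied across the extension sources: per-function using-directives are replaced by a single C++17 nested-namespace block per translation unit, which lets unqualified names such as DType and TensorWrapper resolve through the enclosing transformer_engine namespace. Schematically (the function name here is a hypothetical placeholder):

    // Before: names pulled in locally, repeated in every function.
    void some_extension_fn() {
      using namespace transformer_engine;
      using namespace transformer_engine::pytorch;
      // ... body using DType, TensorWrapper, ...
    }

    // After: one enclosing block per translation unit. Unqualified names
    // like DType are found via enclosing-namespace lookup, so no
    // using-directives are needed.
    namespace transformer_engine::pytorch {

    void some_extension_fn() {
      // ... body using DType, TensorWrapper, ...
    }

    }  // namespace transformer_engine::pytorch

This matches the consistent transformer_engine::pytorch:: qualification now used throughout the pybind registrations above.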