Unverified commit cfbbfb89, authored by Kirthi Shankar Sivamani, committed by GitHub

Cleanup pytorch extensions (#1781)



* rm unused swizzle extensions
  Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix swizzle
  Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Consistent namespaces and first refactor
  Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* format and lint
  Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* transformer_engine
  Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* revert accidental perm change
  Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent f966d5f7
@@ -19,11 +19,7 @@ def setup_pytorch_extension(
     """Setup CUDA extension for PyTorch support"""
     # Source files
-    csrc_source_files = Path(csrc_source_files)
-    extensions_dir = csrc_source_files / "extensions"
-    sources = [
-        csrc_source_files / "common.cpp",
-    ] + all_files_in_dir(extensions_dir)
+    sources = all_files_in_dir(Path(csrc_source_files), name_extension="cpp")

     # Header files
     include_dirs = get_cuda_include_dirs()
...
@@ -56,7 +56,7 @@ def all_files_in_dir(path, name_extension=None):
     all_files = []
     for dirname, _, names in os.walk(path):
         for name in names:
-            if name_extension is not None and name_extension not in name:
+            if name_extension is not None and not name.endswith(f".{name_extension}"):
                 continue
             all_files.append(Path(dirname, name))
     return all_files
...
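The filter above now treats name_extension as a file suffix rather than a substring, so "common.cpp" matches extension "cpp" while a file whose name merely contains the letters (say "cpp_helpers.h") no longer does. Below is a rough C++ analogue of the fixed helper, kept in the patch's main language; the function and parameter names are illustrative assumptions, not part of this change:

// C++ sketch of the fixed all_files_in_dir semantics: collect files whose
// name ends with ".<ext>" (a suffix match, not a substring match).
#include <filesystem>
#include <string>
#include <vector>

std::vector<std::filesystem::path> files_with_extension(
    const std::filesystem::path& root, const std::string& ext) {
  std::vector<std::filesystem::path> files;
  const std::string suffix = "." + ext;
  for (const auto& entry : std::filesystem::recursive_directory_iterator(root)) {
    if (!entry.is_regular_file()) continue;
    const std::string name = entry.path().filename().string();
    // Suffix check: "common.cpp" matches ext "cpp"; "cpp_helpers.h" does not.
    if (name.size() >= suffix.size() &&
        name.compare(name.size() - suffix.size(), suffix.size(), suffix) == 0) {
      files.push_back(entry.path());
    }
  }
  return files;
}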
@@ -9,6 +9,7 @@
 #include "c10/util/ArrayRef.h"
 #include "pybind.h"
 #include "transformer_engine/transformer_engine.h"
+
 namespace transformer_engine::pytorch {

 std::vector<size_t> getTensorShape(at::Tensor t) {
...
@@ -11,32 +11,31 @@
 #include "common.h"

+namespace transformer_engine::pytorch {
+
 /***************************************************************************************************
  * Permutation
  **************************************************************************************************/

 std::tuple<at::Tensor, at::Tensor, std::vector<at::Tensor>> moe_permute_fwd(
-    at::Tensor input, const transformer_engine::DType dtype, at::Tensor indices,
-    int64_t num_out_tokens, std::vector<at::Tensor> workspace, int64_t max_expanded_token_num);
+    at::Tensor input, const DType dtype, at::Tensor indices, int64_t num_out_tokens,
+    std::vector<at::Tensor> workspace, int64_t max_expanded_token_num);

-at::Tensor moe_permute_bwd(at::Tensor input, const transformer_engine::DType dtype,
-                           at::Tensor row_id_map, at::Tensor prob, int64_t num_tokens,
-                           int64_t topK);
+at::Tensor moe_permute_bwd(at::Tensor input, const DType dtype, at::Tensor row_id_map,
+                           at::Tensor prob, int64_t num_tokens, int64_t topK);

-at::Tensor moe_unpermute_fwd(at::Tensor input, const transformer_engine::DType dtype,
-                             at::Tensor row_id_map, at::Tensor prob, int64_t num_tokens,
-                             int64_t topK);
+at::Tensor moe_unpermute_fwd(at::Tensor input, const DType dtype, at::Tensor row_id_map,
+                             at::Tensor prob, int64_t num_tokens, int64_t topK);

 std::tuple<at::Tensor, at::Tensor> moe_unpermute_bwd(at::Tensor input_bwd, at::Tensor input_fwd,
-                                                     const transformer_engine::DType dtype,
-                                                     at::Tensor row_id_map, at::Tensor prob);
+                                                     const DType dtype, at::Tensor row_id_map,
+                                                     at::Tensor prob);

 /***************************************************************************************************
  * Attention
  **************************************************************************************************/

-NVTE_Fused_Attn_Backend get_fused_attn_backend(const transformer_engine::DType q_dtype,
-                                               const transformer_engine::DType kv_dtype,
+NVTE_Fused_Attn_Backend get_fused_attn_backend(const DType q_dtype, const DType kv_dtype,
                                                NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
                                                NVTE_Mask_Type attn_mask_type, float p_dropout,
                                                size_t num_attn_heads, size_t num_gqa_groups,
@@ -61,8 +60,8 @@ std::vector<py::object> fused_attn_bwd(
     NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
     const std::vector<int64_t> window_size, bool deterministic, const at::Tensor cu_seqlens_q,
     const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V,
-    const py::handle O, const py::handle dO, const at::ScalarType fake_dtype,
-    const transformer_engine::DType dqkv_type, const std::vector<at::Tensor> Aux_CTX_Tensors,
+    const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const DType dqkv_type,
+    const std::vector<at::Tensor> Aux_CTX_Tensors,
     const std::optional<at::Tensor> cu_seqlens_q_padded,
     const std::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
     py::handle dp_quantizer, py::handle dqkv_quantizer);
@@ -83,25 +82,29 @@ void copy_to_kv_cache(at::Tensor new_k, at::Tensor new_v, at::Tensor k_cache, at

 using MaybeTensor = std::optional<at::Tensor>;

+std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool transb, py::object D,
+                             py::handle quantizer, std::optional<DType> out_dtype, MaybeTensor bias,
+                             DType bias_type, bool gelu, MaybeTensor gelu_in, bool grad,
+                             at::Tensor workspace, size_t workspaceSize, bool accumulate,
+                             bool use_split_accumulator, CommOverlapCore *comm_overlap = nullptr,
+                             std::optional<CommOverlapType> comm_type = std::nullopt,
+                             MaybeTensor extra_output = std::nullopt, bool bulk_overlap = false);
+
-void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType A_type,
+void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type,
                     std::vector<int64_t> A_scaling_mode, bool transa, at::Tensor B,
-                    at::Tensor B_scale_inverse, transformer_engine::DType B_type,
-                    std::vector<int64_t> B_scaling_mode, bool transb, at::Tensor D,
-                    at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax,
-                    at::Tensor bias, transformer_engine::DType bias_type, at::Tensor pre_gelu_out,
-                    bool grad, at::Tensor workspace, size_t workspaceSize, bool accumulate,
+                    at::Tensor B_scale_inverse, DType B_type, std::vector<int64_t> B_scaling_mode,
+                    bool transb, at::Tensor D, at::Tensor D_scale, DType D_type, at::Tensor D_amax,
+                    at::Tensor bias, DType bias_type, at::Tensor pre_gelu_out, bool grad,
+                    at::Tensor workspace, size_t workspaceSize, bool accumulate,
                     bool use_split_accumulator, int math_sm_count, int m_split, int n_split,
                     bool gemm_producer, at::Tensor counter);

 std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
     std::vector<py::handle> A, bool transa, std::vector<py::handle> B, bool transb,
-    std::optional<std::vector<at::Tensor>> D, transformer_engine::DType D_type,
-    std::vector<int64_t> m_splits, std::vector<at::Tensor> bias,
-    transformer_engine::DType bias_type, bool single_output, std::vector<at::Tensor> pre_gelu_out,
-    bool grad, std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
-    bool use_split_accumulator, int math_sm_count);
+    std::optional<std::vector<at::Tensor>> D, DType D_type, std::vector<int64_t> m_splits,
+    std::vector<at::Tensor> bias, DType bias_type, bool single_output,
+    std::vector<at::Tensor> pre_gelu_out, bool grad, std::vector<at::Tensor> workspace,
+    size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count);

-namespace transformer_engine::pytorch {
-
 /***************************************************************************************************
  * Transpose
@@ -109,16 +112,11 @@ namespace transformer_engine::pytorch {
 std::vector<py::object> fused_multi_quantize(std::vector<at::Tensor> input_list,
                                              std::optional<std::vector<py::object>> output_list,
-                                             std::vector<py::handle> quantizer_list,
-                                             transformer_engine::DType otype);
+                                             std::vector<py::handle> quantizer_list, DType otype);

-at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype,
+at::Tensor fp8_transpose(at::Tensor input, DType otype,
                          std::optional<at::Tensor> output = std::nullopt);

-}  // namespace transformer_engine::pytorch
-
-namespace transformer_engine::pytorch {
-
 /***************************************************************************************************
  * Activations
  **************************************************************************************************/
@@ -155,8 +153,6 @@ py::object dqgelu(const at::Tensor &grad, const at::Tensor &input, py::handle qu
 py::object dsrelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);

-}  // namespace transformer_engine::pytorch
-
 /***************************************************************************************************
  * LayerNorm
  **************************************************************************************************/
@@ -168,7 +164,7 @@ std::vector<py::object> layernorm_bwd(const at::Tensor &dz, const at::Tensor &x,
 std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, MaybeTensor bias,
                                       float eps, py::object ln_out, py::handle quantizer,
-                                      transformer_engine::DType out_dtype, const int sm_margin,
+                                      DType out_dtype, const int sm_margin,
                                       const bool zero_centered_gamma);

 /***************************************************************************************************
@@ -180,35 +176,24 @@ std::vector<py::object> rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x,
                                     const int sm_margin, const bool zero_centered_gamma);

 std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &weight, float eps,
-                                    py::object ln_out, py::handle quantizer,
-                                    transformer_engine::DType otype, const int sm_margin,
-                                    const bool zero_centered_gamma);
+                                    py::object ln_out, py::handle quantizer, DType otype,
+                                    const int sm_margin, const bool zero_centered_gamma);

 /***************************************************************************************************
  * Cast
  **************************************************************************************************/

-namespace transformer_engine::pytorch {
-
 py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::object &output,
                     std::optional<at::Tensor> noop);

-py::object dequantize(const py::handle &input, transformer_engine::DType otype);
+py::object dequantize(const py::handle &input, DType otype);

-std::vector<py::object> bgrad_quantize(const at::Tensor &input, py::handle py_quantizer);
-
-std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool transb, py::object D,
-                             py::handle quantizer, std::optional<DType> out_dtype, MaybeTensor bias,
-                             DType bias_type, bool gelu, MaybeTensor gelu_in, bool grad,
-                             at::Tensor workspace, size_t workspaceSize, bool accumulate,
-                             bool use_split_accumulator, CommOverlapCore *comm_overlap = nullptr,
-                             std::optional<CommOverlapType> comm_type = std::nullopt,
-                             MaybeTensor extra_output = std::nullopt, bool bulk_overlap = false);
-
 /***************************************************************************************************
- * Cast fusions
+ * Bias gradient fusions
  **************************************************************************************************/

+std::vector<py::object> bgrad_quantize(const at::Tensor &input, py::handle py_quantizer);
+
 std::vector<py::object> dbias_dgelu(const at::Tensor &grad_output, const at::Tensor &act_input,
                                     py::handle quantizer);
@@ -224,8 +209,6 @@ std::vector<py::object> dbias_dqgelu(const at::Tensor &grad_output, const at::Te
 std::vector<py::object> dbias_dsrelu(const at::Tensor &grad_output, const at::Tensor &act_input,
                                      py::handle quantizer);

-}  // namespace transformer_engine::pytorch
-
 /***************************************************************************************************
  * Softmax
  **************************************************************************************************/
@@ -262,7 +245,7 @@ void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reductio
                                                  std::vector<at::Tensor> amax_histories,
                                                  std::vector<at::Tensor> scales,
                                                  const std::string &amax_compute_algo,
-                                                 transformer_engine::DType fp8_dtype, float margin);
+                                                 DType fp8_dtype, float margin);

 // Note that the start_offset is the logical offset along the tensor dimension.
 // The offset in bytes is start_offset * sizeof(tensor.dtype)
@@ -271,7 +254,7 @@ void fp8_block_scaling_compute_partial_amax(const at::Tensor &tensor, at::Tensor
 void fp8_block_scaling_partial_cast(const at::Tensor &inp, at::Tensor out, const at::Tensor &scale,
                                     size_t h, size_t w, size_t start_offset, size_t block_len,
-                                    const transformer_engine::DType out_dtype);
+                                    const DType out_dtype);

 /***************************************************************************************************
  * Rotary positional embedding
@@ -335,7 +318,6 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_unscale_l2norm_cuda(
     int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
     at::Tensor inv_scale, at::optional<bool> per_tensor_python);

-using transformer_engine::DType;
 void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
                             std::vector<std::vector<at::Tensor>> tensor_lists, const float lr,
                             const float beta1, const float beta2, const float epsilon,
@@ -389,7 +371,6 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output,
  * NVSHMEM APIs
  **************************************************************************************************/

-namespace nvshmem_api {
 void init_nvshmem_backend(c10d::ProcessGroup *process_group);

 at::Tensor create_nvshmem_tensor(const std::vector<int64_t> &shape, c10::ScalarType dtype);
@@ -399,15 +380,8 @@ void nvshmem_send_on_current_stream(at::Tensor src, at::Tensor dst, int peer, at
 void nvshmem_wait_on_current_stream(at::Tensor signal, const std::string &wait_kind);

 void nvshmem_finalize();
-}  // namespace nvshmem_api

-/***************************************************************************************************
- * swizzle
- **************************************************************************************************/
-
-at::Tensor rowwise_swizzle(at::Tensor input, at::Tensor scale_inv);
-at::Tensor columnwise_swizzle(at::Tensor input, at::Tensor scale_inv);
+}  // namespace transformer_engine::pytorch

 /***************************************************************************************************
  * Comm+GEMM Overlap Wrappers
...
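The recurring theme across these header hunks is the namespace consolidation: the declarations now live inside one enclosing transformer_engine::pytorch block, so an unqualified DType resolves to transformer_engine::DType through ordinary enclosing-namespace lookup, and the repeated qualifications (plus the stray using-declaration and the separate nvshmem_api namespace) can go. A self-contained sketch of that lookup rule, with a toy enum and an invented function name rather than the real declarations:

// Toy sketch of why unqualified DType works inside transformer_engine::pytorch.
namespace transformer_engine {
enum class DType { kFloat32, kBFloat16 };  // stand-in for the real enum

namespace pytorch {
// Name lookup for DType first tries transformer_engine::pytorch, then the
// enclosing transformer_engine namespace, where it is found.
void layernorm_fwd_example(DType out_dtype);  // was: transformer_engine::DType
}  // namespace pytorch
}  // namespace transformer_engine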
@@ -4,15 +4,16 @@
  * See LICENSE for license information.
  ************************************************************************/

+#include "common.h"
 #include "extensions.h"

+namespace transformer_engine::pytorch {
 at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs,
                               const std::optional<at::Tensor> start_positions,
                               const NVTE_QKV_Format qkv_format, const bool interleaved,
                               const std::optional<at::Tensor> cu_seqlens, const int cp_size,
                               const int cp_rank) {
-  using namespace transformer_engine::pytorch;
   TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor");
   TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1,
               "expected the second and third dims of the freqs tensor equal 1");
@@ -27,7 +28,7 @@ at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs,
   auto freqs_cu = makeTransformerEngineTensor(freqs);
   auto output_cu = makeTransformerEngineTensor(output);

-  auto start_positions_cu = transformer_engine::TensorWrapper();  // empty cu_seqlens tensor
+  auto start_positions_cu = TensorWrapper();  // empty cu_seqlens tensor
   if (start_positions) {
     start_positions_cu = makeTransformerEngineTensor(start_positions.value());
   }
@@ -92,7 +93,7 @@ at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs,
               "expected the last dim of the input tensor equals or is "
               "greater than the freqs tensor");

-  auto cu_seqlens_cu = transformer_engine::TensorWrapper();  // empty cu_seqlens tensor
+  auto cu_seqlens_cu = TensorWrapper();  // empty cu_seqlens tensor

   nvte_fused_rope_forward(input_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(),
                           start_positions_cu.data(), output_cu.data(), qkv_format, interleaved,
                           cp_size, cp_rank, s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d,
@@ -105,7 +106,6 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor
                                const NVTE_QKV_Format qkv_format, const bool interleaved,
                                const std::optional<at::Tensor> cu_seqlens, const int cp_size,
                                const int cp_rank) {
-  using namespace transformer_engine::pytorch;
   TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor");
   TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1,
               "expected the second and third dims of the freqs tensor equal 1");
@@ -184,7 +184,7 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor
               "expected the last dim of the output_grads tensor equals or is "
               "greater than the freqs tensor");

-  auto cu_seqlens_cu = transformer_engine::TensorWrapper();  // empty cu_seqlens tensor
+  auto cu_seqlens_cu = TensorWrapper();  // empty cu_seqlens tensor

   nvte_fused_rope_backward(output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(),
                            input_grads_cu.data(), qkv_format, interleaved, cp_size, cp_rank, s, b,
                            h, d, d2, stride_s, stride_b, stride_h, stride_d,
@@ -192,3 +192,5 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor
   return input_grads;
 }
+
+}  // namespace transformer_engine::pytorch
@@ -4,24 +4,13 @@
  * See LICENSE for license information.
  ************************************************************************/

+#include "common.h"
 #include "extensions.h"
 #include "pybind.h"

-constexpr int block_size = 512;
+namespace {

-// get the fused attention backend
-NVTE_Fused_Attn_Backend get_fused_attn_backend(
-    const transformer_engine::DType q_dtype, const transformer_engine::DType kv_dtype,
-    NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
-    float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
-    size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
-    int64_t window_size_right) {
-  NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
-      static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type,
-      attn_mask_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv,
-      head_dim_qk, head_dim_v, window_size_left, window_size_right);
-  return fused_attention_backend;
-}
+constexpr int block_size = 512;

 // fast zero-fills of tensors
 void mha_fill(const transformer_engine::TensorWrapper &self, const at::Tensor &start_index) {
@@ -62,6 +51,23 @@ at::PhiloxCudaState init_philox_state(at::CUDAGeneratorImpl *gen, size_t elts_pe
   return philox_args;
 }

+}  // namespace
+
+namespace transformer_engine::pytorch {
+
+// get the fused attention backend
+NVTE_Fused_Attn_Backend get_fused_attn_backend(
+    const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
+    NVTE_Mask_Type attn_mask_type, float p_dropout, size_t num_attn_heads, size_t num_gqa_groups,
+    size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v,
+    int64_t window_size_left, int64_t window_size_right) {
+  NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
+      static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type,
+      attn_mask_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv,
+      head_dim_qk, head_dim_v, window_size_left, window_size_right);
+  return fused_attention_backend;
+}
+
 // fused attention FWD with separate Q, K and V tensors
 std::vector<py::object> fused_attn_fwd(
     size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout,
@@ -74,8 +80,6 @@ std::vector<py::object> fused_attn_fwd(
     const std::optional<at::Tensor> page_table_k, const std::optional<at::Tensor> page_table_v,
     py::handle s_quantizer, py::handle o_quantizer, const std::optional<at::Tensor> Bias,
     const std::optional<at::Generator> rng_gen, size_t rng_elts_per_thread) {
-  using namespace transformer_engine;
-  using namespace transformer_engine::pytorch;
   TensorWrapper te_Q, te_K, te_V, te_O, te_S;

   auto none = py::none();
@@ -87,8 +91,8 @@ std::vector<py::object> fused_attn_fwd(
   te_V = makeTransformerEngineTensor(V, none);

   // If qkv has FP8 dtype, fake_dtype_te is equal to the fake dtype of q, k, v - needed since torch do not have fp8 types.
-  const transformer_engine::DType qkv_type = te_Q.dtype();
-  const transformer_engine::DType fake_dtype_te = GetTransformerEngineDType(fake_dtype);
+  const DType qkv_type = te_Q.dtype();
+  const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype);

   std::vector<size_t> q_shape = convertShape(te_Q.shape());
   std::vector<size_t> k_shape = convertShape(te_K.shape());
@@ -260,13 +264,11 @@ std::vector<py::object> fused_attn_bwd(
     NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
     const std::vector<int64_t> window_size, bool deterministic, const at::Tensor cu_seqlens_q,
     const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V,
-    const py::handle O, const py::handle dO, const at::ScalarType fake_dtype,
-    const transformer_engine::DType dqkv_type, const std::vector<at::Tensor> Aux_CTX_Tensors,
+    const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const DType dqkv_type,
+    const std::vector<at::Tensor> Aux_CTX_Tensors,
     const std::optional<at::Tensor> cu_seqlens_q_padded,
     const std::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
     py::handle dp_quantizer, py::handle dqkv_quantizer) {
-  using namespace transformer_engine;
-  using namespace transformer_engine::pytorch;
   auto none = py::none();
   TensorWrapper te_Q, te_K, te_V, te_O, te_dO, te_S, te_dP, te_dQ, te_dK, te_dV;
   te_Q = makeTransformerEngineTensor(Q, none);
@@ -276,8 +278,8 @@ std::vector<py::object> fused_attn_bwd(
   te_dO = makeTransformerEngineTensor(dO, none);
   // qkv type from the te_Q
   std::unique_ptr<Quantizer> dQKV_quantizer = convert_quantizer(dqkv_quantizer);
-  const transformer_engine::DType qkv_type = te_Q.dtype();
-  const transformer_engine::DType fake_dtype_te = GetTransformerEngineDType(fake_dtype);
+  const DType qkv_type = te_Q.dtype();
+  const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype);

   py::object s_python, dp_python;
   std::unique_ptr<Quantizer> S_quantizer = convert_quantizer(s_quantizer);
@@ -497,9 +499,6 @@ std::vector<py::object> fused_attn_bwd(
 }

 at::Tensor fa_prepare_fwd(at::Tensor qkvi) {
-  using namespace transformer_engine;
-  using namespace transformer_engine::pytorch;
   NVTE_CHECK(qkvi.dim() == 4, "Expected 4-dim tensor.");
   NVTE_CHECK(qkvi.scalar_type() == at::ScalarType::Half ||
              qkvi.scalar_type() == at::ScalarType::BFloat16);
@@ -521,9 +520,6 @@ at::Tensor fa_prepare_fwd(at::Tensor qkvi) {
 }

 at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v) {
-  using namespace transformer_engine;
-  using namespace transformer_engine::pytorch;
   NVTE_CHECK(q.is_contiguous());
   NVTE_CHECK(k.is_contiguous());
   NVTE_CHECK(v.is_contiguous());
@@ -556,9 +552,6 @@ at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v) {

 at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_seqlens,
                                 int half_idx) {
-  using namespace transformer_engine;
-  using namespace transformer_engine::pytorch;
   NVTE_CHECK(tensor.dim() == 3 || tensor.dim() == 4);
   NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int);
   NVTE_CHECK(cu_seqlens.dim() == 1);
@@ -600,9 +593,6 @@ at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_s

 void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_step,
                                     const at::Tensor &cu_seqlens, bool lse_packed) {
-  using namespace transformer_engine;
-  using namespace transformer_engine::pytorch;
   NVTE_CHECK(lse.scalar_type() == at::ScalarType::Float);
   NVTE_CHECK(lse_per_step.scalar_type() == at::ScalarType::Float);
   NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int);
@@ -647,9 +637,6 @@ void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_st

 at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_seqlens,
                                     bool lse_packed, int second_half_lse_seqlen) {
-  using namespace transformer_engine;
-  using namespace transformer_engine::pytorch;
   NVTE_CHECK(lse.scalar_type() == at::ScalarType::Float);
   NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int);
   NVTE_CHECK(cu_seqlens.dim() == 1);
@@ -700,9 +687,6 @@ at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_
 void thd_out_correction(at::Tensor out, const at::Tensor &out_per_step, const at::Tensor &lse,
                         const at::Tensor &lse_per_step, const at::Tensor &cu_seqlens,
                         bool only_second_half, bool lse_packed) {
-  using namespace transformer_engine;
-  using namespace transformer_engine::pytorch;
   auto te_out = makeTransformerEngineTensor(out);
   auto te_out_per_step = makeTransformerEngineTensor(out_per_step);
   auto te_lse = makeTransformerEngineTensor(lse);
@@ -720,9 +704,6 @@ void thd_out_correction(at::Tensor out, const at::Tensor &out_per_step, const at
 void thd_grad_correction(at::Tensor grad, const at::Tensor &grad_per_step,
                          const at::Tensor &cu_seqlens, const std::string &first_half,
                          const std::string &second_half) {
-  using namespace transformer_engine;
-  using namespace transformer_engine::pytorch;
   auto te_grad = makeTransformerEngineTensor(grad);
   auto te_grad_per_step = makeTransformerEngineTensor(grad_per_step);
   auto te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens);
@@ -737,9 +718,6 @@ void thd_grad_correction(at::Tensor grad, const at::Tensor &grad_per_step,

 at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_tokens,
                                        int world_size, int rank) {
-  using namespace transformer_engine;
-  using namespace transformer_engine::pytorch;
   NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int);
   NVTE_CHECK(cu_seqlens.dim() == 1);
   NVTE_CHECK(cu_seqlens.size(0) >= 2);
@@ -766,9 +744,6 @@ at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_t
  **************************************************************************************************/

 at::Tensor convert_thd_to_bshd(at::Tensor tensor, at::Tensor cu_seqlens, int b, int max_seq_len) {
-  using namespace transformer_engine;
-  using namespace transformer_engine::pytorch;
   int h = tensor.size(1);
   int d = tensor.size(2);
   std::vector<int64_t> shape = {b, max_seq_len, h, d};
@@ -789,9 +764,6 @@ at::Tensor convert_thd_to_bshd(at::Tensor tensor, at::Tensor cu_seqlens, int b,
  **************************************************************************************************/

 at::Tensor convert_bshd_to_thd(at::Tensor tensor, at::Tensor cu_seqlens, int t) {
-  using namespace transformer_engine;
-  using namespace transformer_engine::pytorch;
   int max_seq_len = tensor.size(1);
   int h = tensor.size(2);
   int d = tensor.size(3);
@@ -812,9 +784,6 @@ void copy_to_kv_cache(at::Tensor new_k, at::Tensor new_v, at::Tensor k_cache, at
                       at::Tensor page_table, at::Tensor cu_new_lens, at::Tensor cu_cached_lens,
                       NVTE_QKV_Format qkv_format, int b, int max_ctx_len, int max_seq_len,
                       int max_pages_per_seq, bool is_non_paged) {
-  using namespace transformer_engine;
-  using namespace transformer_engine::pytorch;
   NVTE_CHECK(k_cache.scalar_type() == v_cache.scalar_type() &&
                  new_k.scalar_type() == new_v.scalar_type() &&
                  new_k.scalar_type() == k_cache.scalar_type(),
@@ -836,3 +805,5 @@ void copy_to_kv_cache(at::Tensor new_k, at::Tensor new_v, at::Tensor k_cache, at
                        qkv_format, b, max_ctx_len, max_seq_len, max_pages_per_seq, is_non_paged,
                        at::cuda::getCurrentCUDAStream());
 }
+
+}  // namespace transformer_engine::pytorch
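Alongside the namespace wrapping, this file moves its file-local helpers (block_size, mha_fill, init_philox_state) into an anonymous namespace, which gives them internal linkage: they cannot collide with identically named symbols in other translation units, while the public entry points stay in transformer_engine::pytorch. A minimal sketch of that split, with toy names rather than the real helpers:

// Illustrative only: the internal-linkage pattern applied in this file.
namespace {  // anonymous namespace: everything here is private to this .cpp
constexpr int block_size = 512;

int file_local_helper(int x) { return x * block_size; }
}  // namespace

namespace transformer_engine::pytorch {
// Exported functions remain in the library namespace and may still call the
// file-local helpers defined above.
int public_entry_example(int x) { return file_local_helper(x); }
}  // namespace transformer_engine::pytorch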
@@ -6,11 +6,10 @@
 #include "extensions.h"

+namespace transformer_engine::pytorch {
+
 void fp8_block_scaling_compute_partial_amax(const at::Tensor &tensor, at::Tensor amax, size_t h,
                                             size_t w, size_t start_offset, size_t block_len) {
-  using namespace transformer_engine;
-  using namespace transformer_engine::pytorch;
   TORCH_CHECK(block_len == 128, "Currently only block_len = 128 is supported");
   TORCH_CHECK(amax.dim() == 2, "amax must be a 2D tensor");
   TORCH_CHECK(amax.scalar_type() == at::ScalarType::Float, "amax must be a float tensor");
@@ -29,9 +28,6 @@ void fp8_block_scaling_compute_partial_amax(const at::Tensor &tensor, at::Tensor
 void fp8_block_scaling_partial_cast(const at::Tensor &inp, at::Tensor out, const at::Tensor &scale,
                                     size_t h, size_t w, size_t start_offset, size_t block_len,
                                     const transformer_engine::DType out_dtype) {
-  using namespace transformer_engine;
-  using namespace transformer_engine::pytorch;
   TORCH_CHECK(block_len == 128, "Currently only block_len = 128 is supported");
   TORCH_CHECK(scale.dim() == 2, "scale must be a 2D tensor");
   TORCH_CHECK(scale.scalar_type() == at::ScalarType::Float, "scale must be a float tensor");
@@ -51,3 +47,5 @@ void fp8_block_scaling_partial_cast(const at::Tensor &inp, at::Tensor out, const
       inp_cu.data(), out_cu.data(), scale_cu.data(), h, w, scale.stride(0), scale.stride(1),
       start_offset, block_len, static_cast<NVTEDType>(out_dtype), at::cuda::getCurrentCUDAStream());
 }
+
+}  // namespace transformer_engine::pytorch
@@ -20,12 +20,12 @@
 namespace {

-void* get_data_ptr(MaybeTensor tensor) {
+void* get_data_ptr(transformer_engine::pytorch::MaybeTensor tensor) {
   if (tensor.has_value()) return tensor->data_ptr();
   return nullptr;
 }

-size_t get_size(MaybeTensor tensor, int dim) {
+size_t get_size(transformer_engine::pytorch::MaybeTensor tensor, int dim) {
   if (tensor.has_value()) return static_cast<size_t>(tensor->size(dim));
   return 0;
 }
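These helpers live in an anonymous namespace outside transformer_engine::pytorch, so once the file's code moves into that namespace the MaybeTensor alias has to be fully qualified here. A compilable sketch of why, with std::optional<int> standing in for std::optional<at::Tensor>:

// Toy sketch: code outside a namespace must qualify names declared inside it.
#include <optional>

namespace transformer_engine::pytorch {
using MaybeTensor = std::optional<int>;  // stand-in for the real alias
}

namespace {
// Unqualified MaybeTensor would not be found in this scope; the alias must
// be reached through its full namespace path.
int get_or_zero(transformer_engine::pytorch::MaybeTensor t) {
  return t.has_value() ? *t : 0;
}
}  // namespace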
@@ -271,20 +271,14 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
   return out;
 }

-}  // namespace transformer_engine::pytorch
-
-void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType A_type,
+void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type,
                     std::vector<int64_t> A_scaling_mode, bool transa, at::Tensor B,
-                    at::Tensor B_scale_inverse, transformer_engine::DType B_type,
-                    std::vector<int64_t> B_scaling_mode, bool transb, at::Tensor D,
-                    at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax,
-                    at::Tensor bias, transformer_engine::DType bias_type, at::Tensor pre_gelu_out,
-                    bool grad, at::Tensor workspace, size_t workspaceSize, bool accumulate,
+                    at::Tensor B_scale_inverse, DType B_type, std::vector<int64_t> B_scaling_mode,
+                    bool transb, at::Tensor D, at::Tensor D_scale, DType D_type, at::Tensor D_amax,
+                    at::Tensor bias, DType bias_type, at::Tensor pre_gelu_out, bool grad,
+                    at::Tensor workspace, size_t workspaceSize, bool accumulate,
                     bool use_split_accumulator, int math_sm_count, int m_split, int n_split,
                     bool gemm_producer, at::Tensor counter) {
-  using namespace transformer_engine;
-  using namespace transformer_engine::pytorch;
   // TODO: Handle scaling modes
   NVTEScalingMode nvte_scaling_modeA = NVTE_DELAYED_TENSOR_SCALING;
   NVTEScalingMode nvte_scaling_modeB = NVTE_DELAYED_TENSOR_SCALING;
@@ -326,13 +320,10 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine
 std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
     std::vector<py::handle> A, bool transa, std::vector<py::handle> B, bool transb,
-    std::optional<std::vector<at::Tensor>> D, transformer_engine::DType D_type,
-    std::vector<int64_t> m_splits, std::vector<at::Tensor> bias,
-    transformer_engine::DType bias_type, bool single_output, std::vector<at::Tensor> pre_gelu_out,
-    bool grad, std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
-    bool use_split_accumulator, int math_sm_count) {
-  using namespace transformer_engine;
-  using namespace transformer_engine::pytorch;
+    std::optional<std::vector<at::Tensor>> D, DType D_type, std::vector<int64_t> m_splits,
+    std::vector<at::Tensor> bias, DType bias_type, bool single_output,
+    std::vector<at::Tensor> pre_gelu_out, bool grad, std::vector<at::Tensor> workspace,
+    size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count) {
   std::vector<NVTETensor> te_A_vector, te_B_vector, te_D_vector, te_bias_vector,
       te_pre_gelu_out_vector, te_workspace_vector;
   std::vector<TensorWrapper> wrappers;
@@ -450,3 +441,5 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
   });
   return bias;
 }
+
+}  // namespace transformer_engine::pytorch
@@ -6,6 +6,10 @@
 #include "extensions.h"

+namespace transformer_engine::pytorch {
+
 size_t get_cublasLt_version() { return cublasLtGetVersion(); }

 size_t get_cudnn_version() { return cudnnGetVersion(); }
+
+}  // namespace transformer_engine::pytorch
@@ -6,14 +6,13 @@
 #include "extensions.h"

+namespace transformer_engine::pytorch {
+
 void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
                             std::vector<std::vector<at::Tensor>> tensor_lists, const float lr,
                             const float beta1, const float beta2, const float epsilon,
                             const int step, const int mode, const int bias_correction,
                             const float weight_decay) {
-  using namespace transformer_engine;
-  using namespace transformer_engine::pytorch;
   auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
   auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
       makeTransformerEngineTensorList(tensor_lists);
@@ -29,9 +28,6 @@ void multi_tensor_adam_param_remainder_cuda(int chunk_size, at::Tensor noop_flag
                                             const float lr, const float beta1, const float beta2,
                                             const float epsilon, const int step, const int mode,
                                             const int bias_correction, const float weight_decay) {
-  using namespace transformer_engine;
-  using namespace transformer_engine::pytorch;
   auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
   auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
       makeTransformerEngineTensorList(tensor_lists);
@@ -48,9 +44,6 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag,
                                 const float beta1, const float beta2, const float epsilon,
                                 const int step, const int mode, const int bias_correction,
                                 const float weight_decay, DType fp8_dtype) {
-  using namespace transformer_engine;
-  using namespace transformer_engine::pytorch;
   auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
   auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
       makeTransformerEngineTensorList(tensor_lists);
@@ -68,9 +61,6 @@ void multi_tensor_adam_capturable_cuda(int chunk_size, at::Tensor noop_flag,
                                        const float epsilon, at::Tensor step, const int mode,
                                        const int bias_correction, const float weight_decay,
                                        at::Tensor inv_scale) {
-  using namespace transformer_engine;
-  using namespace transformer_engine::pytorch;
   auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
   auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
       makeTransformerEngineTensorList(tensor_lists);
@@ -91,9 +81,6 @@ void multi_tensor_adam_capturable_master_cuda(int chunk_size, at::Tensor noop_fl
                                               const float epsilon, at::Tensor step, const int mode,
                                               const int bias_correction, const float weight_decay,
                                               at::Tensor inv_scale) {
-  using namespace transformer_engine;
-  using namespace transformer_engine::pytorch;
   auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
   auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
       makeTransformerEngineTensorList(tensor_lists);
@@ -107,3 +94,5 @@ void multi_tensor_adam_capturable_master_cuda(int chunk_size, at::Tensor noop_fl
       lr_cu.data(), beta1, beta2, epsilon, step_cu.data(), mode, bias_correction, weight_decay,
       inv_scale_cu.data(), device_id, at::cuda::getCurrentCUDAStream());
 }
+
+}  // namespace transformer_engine::pytorch
@@ -6,12 +6,11 @@
 #include "extensions.h"

+namespace transformer_engine::pytorch {
+
 void multi_tensor_compute_scale_and_scale_inv_cuda(
     int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
     float max_fp8, bool force_pow_2_scales, float epsilon) {
-  using namespace transformer_engine;
-  using namespace transformer_engine::pytorch;
   auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
   auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
       makeTransformerEngineTensorList(tensor_lists);
@@ -21,3 +20,5 @@ void multi_tensor_compute_scale_and_scale_inv_cuda(
       chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors, max_fp8,
       force_pow_2_scales, epsilon, device_id, at::cuda::getCurrentCUDAStream());
 }
+
+}  // namespace transformer_engine::pytorch
@@ -6,12 +6,11 @@
 #include "extensions.h"
 
+namespace transformer_engine::pytorch {
+
 std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
     int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
     at::optional<bool> per_tensor_python) {
-  using namespace transformer_engine;
-  using namespace transformer_engine::pytorch;
-
   bool per_tensor = per_tensor_python.has_value() ? per_tensor_python.value() : false;
 
   auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
@@ -57,9 +56,6 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
 std::tuple<at::Tensor, at::Tensor> multi_tensor_unscale_l2norm_cuda(
     int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
     at::Tensor inv_scale, at::optional<bool> per_tensor_python) {
-  using namespace transformer_engine;
-  using namespace transformer_engine::pytorch;
-
   bool per_tensor = per_tensor_python.has_value() ? per_tensor_python.value() : false;
 
   auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
@@ -105,3 +101,5 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_unscale_l2norm_cuda(
   return std::tuple<at::Tensor, at::Tensor>(ret, ret_per_tensor);
 }
+
+} // namespace transformer_engine::pytorch
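
One aside on the context lines above: at::optional already provides value_or, so the explicit has_value()/value() ternary that survives this diff could be collapsed. A behaviorally equivalent one-liner, offered only as a possible follow-up cleanup:

    bool per_tensor = per_tensor_python.value_or(false);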
@@ -6,11 +6,10 @@
 #include "extensions.h"
 
+namespace transformer_engine::pytorch {
+
 void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag,
                              std::vector<std::vector<at::Tensor>> tensor_lists, float scale) {
-  using namespace transformer_engine;
-  using namespace transformer_engine::pytorch;
-
   auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
   auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
       makeTransformerEngineTensorList(tensor_lists);
@@ -19,3 +18,5 @@ void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag,
   nvte_multi_tensor_scale_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists,
                                num_tensors, scale, device_id, at::cuda::getCurrentCUDAStream());
 }
+
+} // namespace transformer_engine::pytorch
@@ -6,13 +6,12 @@
 #include "extensions.h"
 
+namespace transformer_engine::pytorch {
+
 void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
                            std::vector<std::vector<at::Tensor>> tensor_lists, float wd,
                            float momentum, float dampening, float lr, bool nesterov, bool first_run,
                            bool wd_after_momentum, float scale) {
-  using namespace transformer_engine;
-  using namespace transformer_engine::pytorch;
-
   auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
   auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
       makeTransformerEngineTensorList(tensor_lists);
@@ -22,3 +21,5 @@ void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
                              num_tensors, wd, momentum, dampening, lr, nesterov, first_run,
                              wd_after_momentum, scale, device_id, at::cuda::getCurrentCUDAStream());
 }
+
+} // namespace transformer_engine::pytorch
@@ -9,28 +9,11 @@
 #include "pybind.h"
 
 namespace transformer_engine::pytorch {
 
-std::pair<TensorWrapper, py::object> createOutputTensor(const NVTEShape &shape, DType dtype,
-                                                        py::handle quantizer) {
-  std::vector<size_t> shape_vec;
-  for (size_t i = 0; i < shape.ndim; i++) {
-    size_t t = shape.data[i];
-    shape_vec.push_back(t);
-  }
-  std::unique_ptr<Quantizer> my_quantizer = convert_quantizer(quantizer);
-  return my_quantizer->create_tensor(shape_vec, dtype);
-}
-
-std::pair<TensorWrapper, py::object> createOutputTensor(std::vector<size_t> &shape, DType dtype,
-                                                        py::handle quantizer) {
-  std::unique_ptr<Quantizer> my_quantizer = convert_quantizer(quantizer);
-  return my_quantizer->create_tensor(shape, dtype);
-}
-
-} // namespace transformer_engine::pytorch
-
 std::vector<py::object> layernorm_bwd(const at::Tensor &dz, const at::Tensor &x,
                                       const at::Tensor &mu, const at::Tensor &rsigma,
                                       const at::Tensor &gamma, const int sm_margin,
                                       const bool zero_centered_gamma) {
-  using namespace transformer_engine::pytorch;
   const auto &dz_ = dz.contiguous();
   const auto &x_ = x.contiguous();
   const auto &mu_ = mu.contiguous();
@@ -40,7 +23,7 @@ std::vector<py::object> layernorm_bwd(const at::Tensor &dz, const at::Tensor &x,
   auto dx = at::empty_like(x_);
   auto dgamma = at::empty_like(gamma_);
   auto dbeta = at::empty_like(gamma_);
-  transformer_engine::TensorWrapper workspace;
+  TensorWrapper workspace;
 
   auto dz_cu = makeTransformerEngineTensor(dz_);
   auto x_cu = makeTransformerEngineTensor(x_);
@@ -80,8 +63,6 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
                                       DType out_dtype, const int sm_margin,
                                       const bool zero_centered_gamma) {
   using namespace transformer_engine::pytorch::detail;
-  using namespace transformer_engine::pytorch;
-  using namespace transformer_engine;
 
   // Input and param tensors
   auto none = py::none();
@@ -135,7 +116,7 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
   TensorWrapper &kernel_out_cu = force_unfused_kernel ? unquantized_out_cu : out_cu;
 
   // Query workspace size
-  transformer_engine::TensorWrapper workspace;
+  TensorWrapper workspace;
   NVTE_SCOPED_GIL_RELEASE({
     nvte_layernorm_fwd(input_cu.data(), weight_cu.data(), bias_cu.data(), eps, kernel_out_cu.data(),
                        mu_cu.data(), rsigma_cu.data(), workspace.data(),
@@ -202,7 +183,6 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
 std::vector<py::object> rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x,
                                     const at::Tensor &rsigma, const at::Tensor &gamma,
                                     const int sm_margin, const bool zero_centered_gamma) {
-  using namespace transformer_engine::pytorch;
   const auto &dz_ = dz.contiguous();
   const auto &x_ = x.contiguous();
   const auto &rsigma_ = rsigma.contiguous();
@@ -210,7 +190,7 @@ std::vector<py::object> rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x,
   auto dx = at::empty_like(x_);
   auto dgamma = at::empty_like(gamma_);
-  transformer_engine::TensorWrapper workspace;
+  TensorWrapper workspace;
 
   auto dz_cu = makeTransformerEngineTensor(dz_);
   auto x_cu = makeTransformerEngineTensor(x_);
@@ -244,12 +224,9 @@ std::vector<py::object> rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x,
 }
 
 std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &weight, float eps,
-                                    py::object out, py::handle quantizer,
-                                    transformer_engine::DType out_dtype, const int sm_margin,
-                                    const bool zero_centered_gamma) {
+                                    py::object out, py::handle quantizer, DType out_dtype,
+                                    const int sm_margin, const bool zero_centered_gamma) {
   using namespace transformer_engine::pytorch::detail;
-  using namespace transformer_engine::pytorch;
-  using namespace transformer_engine;
 
   // Input and param tensors
   auto none = py::none();
@@ -297,7 +274,7 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
   TensorWrapper &kernel_out_cu = force_unfused_kernel ? unquantized_out_cu : out_cu;
 
   // Query workspace size
-  transformer_engine::TensorWrapper workspace;
+  TensorWrapper workspace;
   NVTE_SCOPED_GIL_RELEASE({
     nvte_rmsnorm_fwd(input_cu.data(), weight_cu.data(), eps, kernel_out_cu.data(), rsigma_cu.data(),
                      workspace.data(),
@@ -360,3 +337,5 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
   return {out, py::none(), py::cast(rsigma)};
 }
+
+} // namespace transformer_engine::pytorch
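
The createOutputTensor overloads deleted above wrapped the quantizer-driven output allocation that the norm wrappers rely on; their new home is not visible in this excerpt. A sketch of what the removed code did, with make_output as an illustrative name:

    std::pair<TensorWrapper, py::object> make_output(std::vector<size_t> &shape, DType dtype,
                                                     py::handle quantizer) {
      // Convert the Python-side quantizer into its C++ counterpart, then let it
      // allocate both the TE output tensor and the owning Python object.
      std::unique_ptr<Quantizer> my_quantizer = convert_quantizer(quantizer);
      return my_quantizer->create_tensor(shape, dtype);
    }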
@@ -17,7 +17,8 @@
 #include <torch/cuda.h>
 #include <torch/extension.h>
 
-namespace nvshmem_api {
+namespace transformer_engine::pytorch {
+
 void init_nvshmem_backend(c10d::ProcessGroup *process_group) {
 #ifdef NVTE_ENABLE_NVSHMEM
   nvshmemx_init_attr_t attr = {};
@@ -126,4 +127,5 @@ void nvshmem_finalize() {
               "distributed process groups when TE is compiled with NVTE_ENABLE_NVSHMEM=1!");
 #endif
 }
-} // namespace nvshmem_api
+
+} // namespace transformer_engine::pytorch
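
Folding nvshmem_api into transformer_engine::pytorch means every binding that referenced the old namespace has to follow suit, which the pybind changes later in this commit do, e.g.:

    m.def("nvshmem_finalize", &transformer_engine::pytorch::nvshmem_finalize, ...);

in place of the previous &nvshmem_api::nvshmem_finalize reference.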
@@ -7,12 +7,11 @@
 #include "extensions.h"
 #include "pybind.h"
 
+namespace transformer_engine::pytorch {
+
 void fused_multi_row_padding(at::Tensor input, at::Tensor output,
                              std::vector<size_t> input_row_list,
                              std::vector<size_t> padded_input_row_list) {
-  using namespace transformer_engine;
-  using namespace transformer_engine::pytorch;
-
   NVTE_CHECK(input_row_list.size() == padded_input_row_list.size(),
              "Number of input row list and padded row list must match.");
   NVTE_CHECK(input.dim() == 2, "Dimension of input must equal 2.");
@@ -22,7 +21,7 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output,
   // Extract properties from PyTorch tensors
   std::vector<void*> input_dptr_list, output_dptr_list;
   std::vector<std::vector<size_t>> input_shape_list, output_shape_list;
-  std::vector<transformer_engine::DType> input_type_list;
+  std::vector<DType> input_type_list;
   void* d_input_ptr = reinterpret_cast<void*>(input.data_ptr());
   void* d_output_ptr = reinterpret_cast<void*>(output.data_ptr());
   for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) {
@@ -52,9 +51,9 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output,
   // Construct TE tensors
   std::vector<NVTETensor> nvte_input_list, nvte_output_list;
-  std::vector<transformer_engine::TensorWrapper> tensor_wrappers;
+  std::vector<TensorWrapper> tensor_wrappers;
   auto make_tensor = [&tensor_wrappers](void* dptr, const std::vector<size_t>& shape,
-                                        transformer_engine::DType dtype) -> NVTETensor {
+                                        DType dtype) -> NVTETensor {
     tensor_wrappers.emplace_back(makeTransformerEngineTensor(dptr, shape, dtype));
     return tensor_wrappers.back().data();
   };
@@ -81,3 +80,5 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output,
                               padded_num_rows_list.data(), at::cuda::getCurrentCUDAStream());
   });
 }
+
+} // namespace transformer_engine::pytorch
@@ -6,10 +6,11 @@
 #include "extensions.h"
 
+namespace transformer_engine::pytorch {
+
 std::tuple<at::Tensor, at::Tensor, std::vector<at::Tensor>> moe_permute_fwd(
-    at::Tensor input, const transformer_engine::DType dtype, at::Tensor indices,
-    int64_t num_out_tokens, std::vector<at::Tensor> workspace, int64_t max_expanded_token_num) {
-  using namespace transformer_engine::pytorch;
+    at::Tensor input, const DType dtype, at::Tensor indices, int64_t num_out_tokens,
+    std::vector<at::Tensor> workspace, int64_t max_expanded_token_num) {
   const int num_tokens = input.size(0);
   int num_cols = input.size(1);
   const int topK = indices.size(1);
@@ -71,28 +72,23 @@ std::tuple<at::Tensor, at::Tensor, std::vector<at::Tensor>> moe_permute_fwd(
       dtype);
   auto sorted_row_id_cu = makeTransformerEngineTensor(
       sorted_row_id_ptr, std::vector<size_t>{static_cast<size_t>(num_tokens * topK)},
-      transformer_engine::DType::kInt32);
+      DType::kInt32);
   auto row_id_map_cu = makeTransformerEngineTensor(row_id_map);
 
   nvte_permute(input_cu.data(), permuted_output_cu.data(), sorted_row_id_cu.data(),
-               row_id_map_cu.data(), transformer_engine::TensorWrapper().data(),
-               transformer_engine::TensorWrapper().data(),
-               transformer_engine::TensorWrapper().data(), num_tokens, topK, num_cols,
-               num_out_tokens, stream);
+               row_id_map_cu.data(), TensorWrapper().data(), TensorWrapper().data(),
+               TensorWrapper().data(), num_tokens, topK, num_cols, num_out_tokens, stream);
 
   return std::make_tuple(permuted_output, row_id_map, workspace);
 }
 
-at::Tensor moe_permute_bwd(at::Tensor input, const transformer_engine::DType dtype,
-                           at::Tensor row_id_map, at::Tensor prob, int64_t num_tokens,
-                           int64_t topK) {
+at::Tensor moe_permute_bwd(at::Tensor input, const DType dtype, at::Tensor row_id_map,
+                           at::Tensor prob, int64_t num_tokens, int64_t topK) {
   return moe_unpermute_fwd(input, dtype, row_id_map, prob, num_tokens, topK);
 }
 
-at::Tensor moe_unpermute_fwd(at::Tensor input, const transformer_engine::DType dtype,
-                             at::Tensor row_id_map, at::Tensor prob, int64_t num_tokens,
-                             int64_t topK) {
-  using namespace transformer_engine::pytorch;
+at::Tensor moe_unpermute_fwd(at::Tensor input, const DType dtype, at::Tensor row_id_map,
+                             at::Tensor prob, int64_t num_tokens, int64_t topK) {
   int num_cols = input.size(1);
 
   // Output buffer alloc
@@ -121,9 +117,8 @@ at::Tensor moe_unpermute_fwd(at::Tensor input, const transformer_engine::DType d
 }
 
 std::tuple<at::Tensor, at::Tensor> moe_unpermute_bwd(at::Tensor input_bwd, at::Tensor input_fwd,
-                                                     const transformer_engine::DType dtype,
-                                                     at::Tensor row_id_map, at::Tensor prob) {
-  using namespace transformer_engine::pytorch;
+                                                     const DType dtype, at::Tensor row_id_map,
+                                                     at::Tensor prob) {
   const int topK = (prob.numel() > 0) ? prob.size(1) : 1;
   const int num_tokens = (prob.numel() > 0) ? prob.size(0) : row_id_map.size(0);
   int num_cols = input_bwd.size(1);
@@ -153,9 +148,11 @@ std::tuple<at::Tensor, at::Tensor> moe_unpermute_bwd(at::Tensor input_bwd, at::T
   auto prob_cu = makeTransformerEngineTensor(prob);
   auto prob_grad_cu = makeTransformerEngineTensor(prob_grad);
 
-  nvte_permute(input_bwd_cu.data(), act_grad_cu.data(), transformer_engine::TensorWrapper().data(),
+  nvte_permute(input_bwd_cu.data(), act_grad_cu.data(), TensorWrapper().data(),
               row_id_map_cu.data(), prob_cu.data(), prob_grad_cu.data(), input_fwd_cu.data(),
               num_tokens, topK, num_cols, 0, stream);
 
   return std::make_tuple(act_grad, prob_grad);
 }
+
+} // namespace transformer_engine::pytorch
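
The tightened nvte_permute calls lean on an idiom worth spelling out: a default-constructed TensorWrapper yields an empty NVTETensor, which the kernel treats as an absent optional argument. An annotated copy of the forward call above (the role comments are interpretive, not from the source):

    nvte_permute(input_cu.data(), permuted_output_cu.data(), sorted_row_id_cu.data(),
                 row_id_map_cu.data(),
                 TensorWrapper().data(),  // prob: not used in the forward permute
                 TensorWrapper().data(),  // prob_grad: not used in the forward permute
                 TensorWrapper().data(),  // input_fwd: not used in the forward permute
                 num_tokens, topK, num_cols, num_out_tokens, stream);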
@@ -110,10 +110,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         py::arg("workspace_size"), py::arg("accumulate"), py::arg("use_split_accumulator"),
         py::arg("comm_overlap") = nullptr, py::arg("comm_type") = std::nullopt,
         py::arg("extra_output") = std::nullopt, py::arg("bulk_overlap") = false);
-  m.def("rowwise_swizzle", &rowwise_swizzle, "Swizzle rowwise scale inverses.",
-        py::call_guard<py::gil_scoped_release>());
-  m.def("columnwise_swizzle", &columnwise_swizzle, "Swizzle columnwise scale inverses.",
-        py::call_guard<py::gil_scoped_release>());
   m.def("gelu", transformer_engine::pytorch::gelu, "GeLU activation", py::arg("input"),
         py::arg("quantizer"));
   m.def("relu", transformer_engine::pytorch::relu, "ReLU activation", py::arg("input"),
@@ -159,170 +155,188 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         py::arg("quantizer"));
 
   // Permutation functions
-  m.def("moe_permute_fwd", moe_permute_fwd, "MOE permute FWD",
+  m.def("moe_permute_fwd", transformer_engine::pytorch::moe_permute_fwd, "MOE permute FWD",
         py::call_guard<py::gil_scoped_release>());
-  m.def("moe_permute_bwd", moe_permute_bwd, "MOE permute BWD",
+  m.def("moe_permute_bwd", transformer_engine::pytorch::moe_permute_bwd, "MOE permute BWD",
         py::call_guard<py::gil_scoped_release>());
-  m.def("moe_unpermute_fwd", moe_unpermute_fwd, "MOE unpermute FWD",
+  m.def("moe_unpermute_fwd", transformer_engine::pytorch::moe_unpermute_fwd, "MOE unpermute FWD",
         py::call_guard<py::gil_scoped_release>());
-  m.def("moe_unpermute_bwd", moe_unpermute_bwd, "MOE unpermute BWD",
+  m.def("moe_unpermute_bwd", transformer_engine::pytorch::moe_unpermute_bwd, "MOE unpermute BWD",
         py::call_guard<py::gil_scoped_release>());
 
   // Softmax functions
-  m.def("scaled_softmax_forward", &scaled_softmax_forward, "Scaled Softmax FWD",
-        py::call_guard<py::gil_scoped_release>());
+  m.def("scaled_softmax_forward", &transformer_engine::pytorch::scaled_softmax_forward,
+        "Scaled Softmax FWD", py::call_guard<py::gil_scoped_release>());
-  m.def("scaled_softmax_backward", &scaled_softmax_backward, "Scaled Softmax BWD",
-        py::call_guard<py::gil_scoped_release>());
+  m.def("scaled_softmax_backward", &transformer_engine::pytorch::scaled_softmax_backward,
+        "Scaled Softmax BWD", py::call_guard<py::gil_scoped_release>());
-  m.def("scaled_masked_softmax_forward", &scaled_masked_softmax_forward,
-        "Scaled Masked Softmax FWD", py::call_guard<py::gil_scoped_release>());
+  m.def("scaled_masked_softmax_forward",
+        &transformer_engine::pytorch::scaled_masked_softmax_forward, "Scaled Masked Softmax FWD",
+        py::call_guard<py::gil_scoped_release>());
-  m.def("scaled_masked_softmax_backward", &scaled_masked_softmax_backward,
-        "Scaled Masked Softmax BWD", py::call_guard<py::gil_scoped_release>());
+  m.def("scaled_masked_softmax_backward",
+        &transformer_engine::pytorch::scaled_masked_softmax_backward, "Scaled Masked Softmax BWD",
+        py::call_guard<py::gil_scoped_release>());
-  m.def("scaled_upper_triang_masked_softmax_forward", &scaled_upper_triang_masked_softmax_forward,
+  m.def("scaled_upper_triang_masked_softmax_forward",
+        &transformer_engine::pytorch::scaled_upper_triang_masked_softmax_forward,
         "Scaled Upper-Triangular Masked Softmax FWD", py::call_guard<py::gil_scoped_release>());
-  m.def("scaled_upper_triang_masked_softmax_backward", &scaled_upper_triang_masked_softmax_backward,
+  m.def("scaled_upper_triang_masked_softmax_backward",
+        &transformer_engine::pytorch::scaled_upper_triang_masked_softmax_backward,
         "Scaled Upper-Triangular Masked Softmax BWD", py::call_guard<py::gil_scoped_release>());
   m.def("scaled_aligned_causal_masked_softmax_forward",
-        &scaled_aligned_causal_masked_softmax_forward,
+        &transformer_engine::pytorch::scaled_aligned_causal_masked_softmax_forward,
         "Scaled Bottom-Right Corner Aligned Masked Softmax FWD",
         py::call_guard<py::gil_scoped_release>());
   m.def("scaled_aligned_causal_masked_softmax_backward",
-        &scaled_aligned_causal_masked_softmax_backward,
+        &transformer_engine::pytorch::scaled_aligned_causal_masked_softmax_backward,
        "Scaled Bottom-Right Corner Aligned Masked Softmax BWD",
        py::call_guard<py::gil_scoped_release>());
 
   // Other granular functions
-  m.def("layernorm_fwd", &layernorm_fwd, "LayerNorm", py::arg("input"), py::arg("weight"),
-        py::arg("bias"), py::arg("eps"), py::arg("ln_out"), py::arg("quantizer"), py::arg("otype"),
-        py::arg("sm_margin"), py::arg("zero_centered_gamma"));
+  m.def("layernorm_fwd", &transformer_engine::pytorch::layernorm_fwd, "LayerNorm", py::arg("input"),
+        py::arg("weight"), py::arg("bias"), py::arg("eps"), py::arg("ln_out"), py::arg("quantizer"),
+        py::arg("otype"), py::arg("sm_margin"), py::arg("zero_centered_gamma"));
-  m.def("layernorm_bwd", &layernorm_bwd, "Backward of LayerNorm");
+  m.def("layernorm_bwd", &transformer_engine::pytorch::layernorm_bwd, "Backward of LayerNorm");
-  m.def("rmsnorm_fwd", &rmsnorm_fwd, "RMSNorm", py::arg("input"), py::arg("weight"), py::arg("eps"),
-        py::arg("ln_out"), py::arg("quantizer"), py::arg("otype"), py::arg("sm_margin"),
-        py::arg("zero_centered_gamma"));
+  m.def("rmsnorm_fwd", &transformer_engine::pytorch::rmsnorm_fwd, "RMSNorm", py::arg("input"),
+        py::arg("weight"), py::arg("eps"), py::arg("ln_out"), py::arg("quantizer"),
+        py::arg("otype"), py::arg("sm_margin"), py::arg("zero_centered_gamma"));
-  m.def("rmsnorm_bwd", &rmsnorm_bwd, "Backward of RMSNorm");
+  m.def("rmsnorm_bwd", &transformer_engine::pytorch::rmsnorm_bwd, "Backward of RMSNorm");
   m.def("fused_multi_quantize", &transformer_engine::pytorch::fused_multi_quantize,
         "Fused Multi-tensor Cast + Transpose", py::arg("input_list"), py::arg("output_list"),
         py::arg("quantizer_list"), py::arg("otype"));
-  m.def("te_general_grouped_gemm", &te_general_grouped_gemm, "Grouped GEMM");
+  m.def("te_general_grouped_gemm", &transformer_engine::pytorch::te_general_grouped_gemm,
+        "Grouped GEMM");
   m.def("fp8_transpose", &transformer_engine::pytorch::fp8_transpose, "Transpose with FP8 I/O",
         py::arg("input"), py::arg("dtype"), py::kw_only(), py::arg("out"),
         py::call_guard<py::gil_scoped_release>());
-  m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend",
-        py::call_guard<py::gil_scoped_release>());
+  m.def("get_fused_attn_backend", &transformer_engine::pytorch::get_fused_attn_backend,
+        "Get Fused Attention backend", py::call_guard<py::gil_scoped_release>());
-  m.def("compute_amax", &compute_amax, "Compute absolute max value in tensor", py::arg("input"),
-        py::arg("amax"), py::call_guard<py::gil_scoped_release>());
+  m.def("compute_amax", &transformer_engine::pytorch::compute_amax,
+        "Compute absolute max value in tensor", py::arg("input"), py::arg("amax"),
+        py::call_guard<py::gil_scoped_release>());
-  m.def("fused_amax_and_scale_update_after_reduction", &fused_amax_and_scale_update_after_reduction,
+  m.def("fused_amax_and_scale_update_after_reduction",
+        &transformer_engine::pytorch::fused_amax_and_scale_update_after_reduction,
         "Update amax history and FP8 scale/scale_inv after reduction",
         py::call_guard<py::gil_scoped_release>());
-  m.def("fp8_block_scaling_compute_partial_amax", &fp8_block_scaling_compute_partial_amax,
+  m.def("fp8_block_scaling_compute_partial_amax",
+        &transformer_engine::pytorch::fp8_block_scaling_compute_partial_amax,
         "Compute partial amax from master weights for fp8 block scaling", py::arg("tensor"),
         py::arg("amax"), py::arg("h"), py::arg("w"), py::arg("start_offset"), py::arg("block_len"),
         py::call_guard<py::gil_scoped_release>());
-  m.def("fp8_block_scaling_partial_cast", &fp8_block_scaling_partial_cast,
+  m.def("fp8_block_scaling_partial_cast",
+        &transformer_engine::pytorch::fp8_block_scaling_partial_cast,
         "Partial cast from master weights for fp8 block scaling", py::arg("inp"), py::arg("out"),
         py::arg("scale"), py::arg("h"), py::arg("w"), py::arg("start_offset"), py::arg("block_len"),
         py::arg("out_dtype"), py::call_guard<py::gil_scoped_release>());
-  m.def("fused_multi_row_padding", &fused_multi_row_padding, "Fused Multi-tensor padding",
-        py::call_guard<py::gil_scoped_release>());
+  m.def("fused_multi_row_padding", &transformer_engine::pytorch::fused_multi_row_padding,
+        "Fused Multi-tensor padding", py::call_guard<py::gil_scoped_release>());
 
   // attention kernels
-  m.def("fa_prepare_fwd", &fa_prepare_fwd, "Prepare QKV for Flash Attention",
-        py::call_guard<py::gil_scoped_release>());
+  m.def("fa_prepare_fwd", &transformer_engine::pytorch::fa_prepare_fwd,
+        "Prepare QKV for Flash Attention", py::call_guard<py::gil_scoped_release>());
-  m.def("fa_prepare_bwd", &fa_prepare_bwd, "Backward of QKV preparation for Flash Attention",
-        py::call_guard<py::gil_scoped_release>());
+  m.def("fa_prepare_bwd", &transformer_engine::pytorch::fa_prepare_bwd,
+        "Backward of QKV preparation for Flash Attention",
+        py::call_guard<py::gil_scoped_release>());
-  m.def("fused_attn_fwd", &fused_attn_fwd,
+  m.def("fused_attn_fwd", &transformer_engine::pytorch::fused_attn_fwd,
         "Fused Attention FP8/BF16/FP16 FWD with separate Q, K and V");
-  m.def("fused_attn_bwd", &fused_attn_bwd,
+  m.def("fused_attn_bwd", &transformer_engine::pytorch::fused_attn_bwd,
         "Fused Attention FP8/BF16/FP16 BWD with separate Q, K and V");
-  m.def("copy_to_kv_cache", &copy_to_kv_cache, "Copy new KV tokens to KV cache",
-        py::call_guard<py::gil_scoped_release>());
+  m.def("copy_to_kv_cache", &transformer_engine::pytorch::copy_to_kv_cache,
+        "Copy new KV tokens to KV cache", py::call_guard<py::gil_scoped_release>());
-  m.def("convert_thd_to_bshd", &convert_thd_to_bshd, "Convert a tensor from THD to BSHD",
-        py::call_guard<py::gil_scoped_release>());
+  m.def("convert_thd_to_bshd", &transformer_engine::pytorch::convert_thd_to_bshd,
+        "Convert a tensor from THD to BSHD", py::call_guard<py::gil_scoped_release>());
-  m.def("convert_bshd_to_thd", &convert_bshd_to_thd, "Convert a tesnor from BSHD to THD",
-        py::call_guard<py::gil_scoped_release>());
+  m.def("convert_bshd_to_thd", &transformer_engine::pytorch::convert_bshd_to_thd,
+        "Convert a tesnor from BSHD to THD", py::call_guard<py::gil_scoped_release>());
 
   // fused apply rope
-  m.def("fused_rope_forward", &fused_rope_forward, "Fused Apply RoPE FWD",
-        py::call_guard<py::gil_scoped_release>());
+  m.def("fused_rope_forward", &transformer_engine::pytorch::fused_rope_forward,
+        "Fused Apply RoPE FWD", py::call_guard<py::gil_scoped_release>());
-  m.def("fused_rope_backward", &fused_rope_backward, "Fused Apply RoPE BWD",
-        py::call_guard<py::gil_scoped_release>());
+  m.def("fused_rope_backward", &transformer_engine::pytorch::fused_rope_backward,
        "Fused Apply RoPE BWD", py::call_guard<py::gil_scoped_release>());
 
   // Misc
-  m.def("get_cublasLt_version", &get_cublasLt_version, "Get cublasLt version",
-        py::call_guard<py::gil_scoped_release>());
+  m.def("get_cublasLt_version", &transformer_engine::pytorch::get_cublasLt_version,
        "Get cublasLt version", py::call_guard<py::gil_scoped_release>());
-  m.def("get_cudnn_version", &get_cudnn_version, "Get cuDNN version",
+  m.def("get_cudnn_version", &transformer_engine::pytorch::get_cudnn_version, "Get cuDNN version",
        py::call_guard<py::gil_scoped_release>());
   m.attr("_num_cublas_streams") = py::int_(transformer_engine::num_streams);
 
   // Support THD format for Context Parallel
-  m.def("thd_read_half_tensor", &thd_read_half_tensor,
+  m.def("thd_read_half_tensor", &transformer_engine::pytorch::thd_read_half_tensor,
         "Read the first half(half_idx=0) or the second half(half_idx=1) of each sequence in a THD "
         "tensor",
         py::call_guard<py::gil_scoped_release>());
-  m.def("thd_second_half_lse_correction", &thd_second_half_lse_correction,
+  m.def("thd_second_half_lse_correction",
+        &transformer_engine::pytorch::thd_second_half_lse_correction,
         "Correct the second half of the softmax_lse", py::call_guard<py::gil_scoped_release>());
-  m.def("thd_read_second_half_lse", &thd_read_second_half_lse,
+  m.def("thd_read_second_half_lse", &transformer_engine::pytorch::thd_read_second_half_lse,
         "Read the second half of the softmax_lse", py::call_guard<py::gil_scoped_release>());
-  m.def("thd_out_correction", &thd_out_correction,
+  m.def("thd_out_correction", &transformer_engine::pytorch::thd_out_correction,
         "Correct the THD format output of context parallelism in forward pass",
         py::call_guard<py::gil_scoped_release>());
-  m.def("thd_grad_correction", &thd_grad_correction,
+  m.def("thd_grad_correction", &transformer_engine::pytorch::thd_grad_correction,
         "Correct the THD format gradients of context parallelism in backward pass",
         py::call_guard<py::gil_scoped_release>());
-  m.def("thd_get_partitioned_indices", &thd_get_partitioned_indices,
+  m.def("thd_get_partitioned_indices", &transformer_engine::pytorch::thd_get_partitioned_indices,
         "Generate partitioned indices for inputs in THD format",
         py::call_guard<py::gil_scoped_release>());
 
   // nvshmem functions
-  m.def("init_nvshmem_backend", &nvshmem_api::init_nvshmem_backend,
+  m.def("init_nvshmem_backend", &transformer_engine::pytorch::init_nvshmem_backend,
         "Initialize nvshmem backend with Pytorch distributed process groups",
         py::call_guard<py::gil_scoped_release>());
-  m.def("create_nvshmem_tensor", &nvshmem_api::create_nvshmem_tensor,
+  m.def("create_nvshmem_tensor", &transformer_engine::pytorch::create_nvshmem_tensor,
         "Create a tensor in NVSHMEM shared memory", py::call_guard<py::gil_scoped_release>());
-  m.def("nvshmem_send_on_current_stream", &nvshmem_api::nvshmem_send_on_current_stream,
+  m.def("nvshmem_send_on_current_stream",
+        &transformer_engine::pytorch::nvshmem_send_on_current_stream,
         "Asynchronously send tensor data to a remote PE using NVSHMEM on the current CUDA stream",
         py::call_guard<py::gil_scoped_release>());
-  m.def("nvshmem_wait_on_current_stream", &nvshmem_api::nvshmem_wait_on_current_stream,
+  m.def("nvshmem_wait_on_current_stream",
+        &transformer_engine::pytorch::nvshmem_wait_on_current_stream,
         "Wait for a signal value to be updated by a remote PE using NVSHMEM on the current CUDA "
         "stream",
         py::call_guard<py::gil_scoped_release>());
-  m.def("nvshmem_finalize", &nvshmem_api::nvshmem_finalize,
+  m.def("nvshmem_finalize", &transformer_engine::pytorch::nvshmem_finalize,
         "Clean up and finalize the NVSHMEM communication backend and free associated resources",
         py::call_guard<py::gil_scoped_release>());
 
   // multi-tensor functions
-  m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
+  m.def("multi_tensor_scale", &transformer_engine::pytorch::multi_tensor_scale_cuda,
         "Fused overflow check + scale for a list of contiguous tensors",
         py::call_guard<py::gil_scoped_release>());
-  m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda,
+  m.def("multi_tensor_l2norm", &transformer_engine::pytorch::multi_tensor_l2norm_cuda,
         "Computes L2 norm for a list of contiguous tensors",
         py::call_guard<py::gil_scoped_release>());
-  m.def("multi_tensor_unscale_l2norm", &multi_tensor_unscale_l2norm_cuda,
+  m.def("multi_tensor_unscale_l2norm",
+        &transformer_engine::pytorch::multi_tensor_unscale_l2norm_cuda,
         "Computes L2 norm for a list of contiguous tensors after unscaling (unscaling is only "
         "performed for L2 norm computation, and tensors are not updated)",
         py::call_guard<py::gil_scoped_release>());
-  m.def("multi_tensor_adam", &multi_tensor_adam_cuda,
+  m.def("multi_tensor_adam", &transformer_engine::pytorch::multi_tensor_adam_cuda,
         "Compute and apply gradient update to parameters for Adam optimizer",
         py::call_guard<py::gil_scoped_release>());
-  m.def("multi_tensor_adam_param_remainder", &multi_tensor_adam_param_remainder_cuda,
+  m.def("multi_tensor_adam_param_remainder",
+        &transformer_engine::pytorch::multi_tensor_adam_param_remainder_cuda,
         "Compute and apply gradient update to parameters for Adam optimizer"
         "where the master parameters only store the remainder bits",
         py::call_guard<py::gil_scoped_release>());
-  m.def("multi_tensor_adam_fp8", &multi_tensor_adam_fp8_cuda,
+  m.def("multi_tensor_adam_fp8", &transformer_engine::pytorch::multi_tensor_adam_fp8_cuda,
         "Compute and apply gradient update to parameters for Adam optimizer",
         py::call_guard<py::gil_scoped_release>());
-  m.def("multi_tensor_adam_capturable", &multi_tensor_adam_capturable_cuda,
+  m.def("multi_tensor_adam_capturable",
        &transformer_engine::pytorch::multi_tensor_adam_capturable_cuda,
        "Compute and apply gradient update to parameters for Adam optimizer with CUDA graph "
        "support and LR scheduling",
        py::call_guard<py::gil_scoped_release>());
-  m.def("multi_tensor_adam_capturable_master", &multi_tensor_adam_capturable_master_cuda,
+  m.def("multi_tensor_adam_capturable_master",
        &transformer_engine::pytorch::multi_tensor_adam_capturable_master_cuda,
        "Compute and apply gradient update to parameters for Adam optimizer with CUDA graph "
        "support, LR scheduling and FP32 master weights",
        py::call_guard<py::gil_scoped_release>());
-  m.def("multi_tensor_sgd", &multi_tensor_sgd_cuda,
+  m.def("multi_tensor_sgd", &transformer_engine::pytorch::multi_tensor_sgd_cuda,
        "Fused SGD optimizer for list of contiguous tensors",
        py::call_guard<py::gil_scoped_release>());
-  m.def("multi_tensor_compute_scale_and_scale_inv", &multi_tensor_compute_scale_and_scale_inv_cuda,
+  m.def("multi_tensor_compute_scale_and_scale_inv",
        &transformer_engine::pytorch::multi_tensor_compute_scale_and_scale_inv_cuda,
        "Fused compute scale and scale_inv from amax", py::call_guard<py::gil_scoped_release>());
 
   // Data structures
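The binding churn above is mechanical: with the per-file using-directives gone, every wrapper must be referenced through its full transformer_engine::pytorch:: path, while the GIL-release call guards are untouched. A self-contained toy module showing the same pybind11 pattern (demo_ext and demo::add are invented for illustration):

    #include <pybind11/pybind11.h>

    namespace py = pybind11;

    namespace demo {
    int add(int a, int b) { return a + b; }
    } // namespace demo

    PYBIND11_MODULE(demo_ext, m) {
      // Fully qualified function reference plus a call guard that drops the GIL
      // while the C++ body runs, mirroring the m.def calls in the diff above.
      m.def("add", &demo::add, "Add two ints", py::call_guard<py::gil_scoped_release>());
    }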
@@ -12,10 +12,9 @@
 #include "common/common.h"
 #include "extensions.h"
 
-void compute_amax(const at::Tensor& tensor, at::Tensor& amax) {
-  using namespace transformer_engine;
-  using namespace transformer_engine::pytorch;
+namespace transformer_engine::pytorch {
 
+void compute_amax(const at::Tensor& tensor, at::Tensor& amax) {
   auto input_tensor = tensor.contiguous();
   const TensorWrapper& te_input = makeTransformerEngineTensor(input_tensor);
@@ -23,7 +22,7 @@ void compute_amax(const at::Tensor& tensor, at::Tensor& amax) {
   TORCH_CHECK(amax.numel() == 1, "amax must have exactly one element");
   TensorWrapper fake_te_output(
       nullptr, te_input.shape(),
-      transformer_engine::DType::kFloat8E4M3, // It doesn't matter because we only compute amax.
+      DType::kFloat8E4M3, // It doesn't matter because we only compute amax.
       amax.data_ptr<float>());
   nvte_compute_amax(te_input.data(), fake_te_output.data(), at::cuda::getCurrentCUDAStream());
@@ -33,10 +32,7 @@ void fused_amax_and_scale_update_after_reduction(const at::Tensor& amax_reductio
                                                  std::vector<at::Tensor> amax_histories,
                                                  std::vector<at::Tensor> scales,
                                                  const std::string& amax_compute_algo,
-                                                 transformer_engine::DType fp8_dtype,
-                                                 float margin) {
-  using namespace transformer_engine;
-  using namespace transformer_engine::pytorch;
+                                                 DType fp8_dtype, float margin) {
   size_t num_tensors = amax_histories.size();
   std::vector<Tensor> t_amax_histories(num_tensors);
   std::vector<Tensor> t_scales(num_tensors);
@@ -63,3 +59,5 @@ void fused_amax_and_scale_update_after_reduction(const at::Tensor& amax_reductio
       amax_compute_algo.c_str(), static_cast<NVTEDType>(fp8_dtype), margin,
       at::cuda::getCurrentCUDAStream());
 }
+
+} // namespace transformer_engine::pytorch
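
The fake_te_output construction above deserves a gloss: nvte_compute_amax only writes the single-float amax field of its output tensor, so the wrapper hands it a TensorWrapper with a null data pointer, an arbitrary dtype, and just the amax buffer attached. Restated with interpretive comments, per the code above:

    // Only the amax pointer matters: no tensor data is read from or written to
    // the output, so data is nullptr and the dtype is a placeholder.
    TensorWrapper fake_te_output(nullptr, te_input.shape(),
                                 DType::kFloat8E4M3, // arbitrary; only amax is written
                                 amax.data_ptr<float>());
    nvte_compute_amax(te_input.data(), fake_te_output.data(), at::cuda::getCurrentCUDAStream());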