Unverified Commit 7b94bd99 authored by Oleg Goncharov's avatar Oleg Goncharov Committed by GitHub
Browse files

[common] Added support of FP4 data type (#1779)



* Added support of FP4 data type
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* Refactoring to BitsNum in progress
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* Fixed compilation errors. All C++ tests passed
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* Fixed a typo
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Added FP4 guard to TMA tensor descriptor data type
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fixed errors in JAX C++ extensions
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Removed dummy NVFP4 C++ test file
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* Make pytorch changes
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Refactored the code per the review notes. Fixed JAX build error.
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Removed unnecessary static casts
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* Typo fix
Signed-off-by: default avatarOleg Goncharov <64355998+Oleg-Goncharov@users.noreply.github.com>

* Pass correct num bits to create_2D_tensor_map; fixes CI
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* inline funcs
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: default avatarOleg Goncharov <64355998+Oleg-Goncharov@users.noreply.github.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent e963e4a9
...@@ -67,7 +67,8 @@ inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const ...@@ -67,7 +67,8 @@ inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const
// Remove the use_cudnn check here when it is supported by both backends. // Remove the use_cudnn check here when it is supported by both backends.
const bool zero_centered_gamma_in_weight_dtype = use_cudnn && cudnn_zero_centered_gamma_in_weight_dtype; const bool zero_centered_gamma_in_weight_dtype = use_cudnn && cudnn_zero_centered_gamma_in_weight_dtype;
if constexpr (std::is_same_v<InputType, fp8e5m2> || std::is_same_v<InputType, fp8e4m3>){ if constexpr (std::is_same_v<InputType, fp8e5m2> || std::is_same_v<InputType, fp8e4m3> ||
std::is_same_v<InputType, fp4e2m1>){
compute_t g = static_cast<compute_t>(gamma); compute_t g = static_cast<compute_t>(gamma);
if (zero_centered_gamma) { if (zero_centered_gamma) {
g += static_cast<compute_t>(1.f); g += static_cast<compute_t>(1.f);
......
...@@ -45,7 +45,7 @@ bool areShapesEqual(const NVTEShape &s1, const NVTEShape &s2) { ...@@ -45,7 +45,7 @@ bool areShapesEqual(const NVTEShape &s1, const NVTEShape &s2) {
return true; return true;
} }
size_t typeToSize(DType type) { size_t typeToNumBits(DType type) {
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(type, T, TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(type, T,
{ {
return TypeInfo<T>::size; return TypeInfo<T>::size;
...@@ -62,7 +62,8 @@ const std::string &typeName(DType type) { ...@@ -62,7 +62,8 @@ const std::string &typeName(DType type) {
{DType::kBFloat16, "bfloat16"}, {DType::kBFloat16, "bfloat16"},
{DType::kFloat8E4M3, "float8e4m3"}, {DType::kFloat8E4M3, "float8e4m3"},
{DType::kFloat8E5M2, "float8e5m2"}, {DType::kFloat8E5M2, "float8e5m2"},
{DType::kFloat8E8M0, "float8e8m0"}}; {DType::kFloat8E8M0, "float8e8m0"},
{DType::kFloat4E2M1, "float4e2m1"}};
return name_map.at(type); return name_map.at(type);
} }
...@@ -109,9 +110,16 @@ size_t DIVUP(const size_t &x, const size_t &y){ ...@@ -109,9 +110,16 @@ size_t DIVUP(const size_t &x, const size_t &y){
struct scale_inv_meta { struct scale_inv_meta {
std::vector<size_t> shape; std::vector<size_t> shape;
DType type; DType type;
size_t type_size; size_t type_size_bits;
size_t bytes() const noexcept {
return (product(shape) * type_size_bits) / 8;
}
}; };
size_t bytes(const NVTEShape& shape, const DType type) {
return (product(shape) * typeToNumBits(type)) / 8;
}
NVTEShape convertShape(const std::vector<size_t>& s) { NVTEShape convertShape(const std::vector<size_t>& s) {
return nvte_make_shape(s.data(), s.size()); return nvte_make_shape(s.data(), s.size());
} }
...@@ -122,7 +130,7 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape, ...@@ -122,7 +130,7 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
scale_inv_meta ret; scale_inv_meta ret;
ret.shape = {1}; ret.shape = {1};
ret.type = DType::kFloat32; ret.type = DType::kFloat32;
ret.type_size = sizeof(float); ret.type_size_bits = typeToNumBits(DType::kFloat32);
return {ret, ret}; return {ret, ret};
} }
if (scaling_mode == NVTE_MXFP8_1D_SCALING) { if (scaling_mode == NVTE_MXFP8_1D_SCALING) {
...@@ -152,8 +160,8 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape, ...@@ -152,8 +160,8 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
} }
ret_rowwise.type = DType::kFloat8E8M0; ret_rowwise.type = DType::kFloat8E8M0;
ret_colwise.type = DType::kFloat8E8M0; ret_colwise.type = DType::kFloat8E8M0;
ret_rowwise.type_size = sizeof(uint8_t); ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0);
ret_colwise.type_size = sizeof(uint8_t); ret_colwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0);
return {ret_rowwise, ret_colwise}; return {ret_rowwise, ret_colwise};
} }
...@@ -179,8 +187,8 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape, ...@@ -179,8 +187,8 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
} }
ret_rowwise.type = DType::kFloat32; ret_rowwise.type = DType::kFloat32;
ret_colwise.type = DType::kFloat32; ret_colwise.type = DType::kFloat32;
ret_rowwise.type_size = sizeof(float); ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat32);
ret_colwise.type_size = sizeof(float); ret_colwise.type_size_bits = typeToNumBits(DType::kFloat32);
return {ret_rowwise, ret_colwise}; return {ret_rowwise, ret_colwise};
} }
...@@ -205,8 +213,8 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape, ...@@ -205,8 +213,8 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
} }
ret_rowwise.type = DType::kFloat32; ret_rowwise.type = DType::kFloat32;
ret_colwise.type = DType::kFloat32; ret_colwise.type = DType::kFloat32;
ret_rowwise.type_size = sizeof(float); ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat32);
ret_colwise.type_size = sizeof(float); ret_colwise.type_size_bits = typeToNumBits(DType::kFloat32);
return {ret_rowwise, ret_colwise}; return {ret_rowwise, ret_colwise};
} }
...@@ -222,8 +230,7 @@ Tensor::Tensor(const std::string& name, ...@@ -222,8 +230,7 @@ Tensor::Tensor(const std::string& name,
gen_.seed(seed); gen_.seed(seed);
rowwise_ = rowwise; rowwise_ = rowwise;
columnwise_ = columnwise; columnwise_ = columnwise;
size_t s = typeToSize(type); size_t total_size = bytes(shape, type);
size_t total_size = product(shape) * s;
void *dptr_rowwise = nullptr; void *dptr_rowwise = nullptr;
void *dptr_columnwise = nullptr; void *dptr_columnwise = nullptr;
cpu_data_rowwise_ = nullptr; cpu_data_rowwise_ = nullptr;
...@@ -305,8 +312,8 @@ Tensor::Tensor(const std::string& name, ...@@ -305,8 +312,8 @@ Tensor::Tensor(const std::string& name,
} else { } else {
auto [rowwise_scale_meta, colwise_scale_meta] = auto [rowwise_scale_meta, colwise_scale_meta] =
get_scales(normalized_shape, tensor_.scaling_mode()); get_scales(normalized_shape, tensor_.scaling_mode());
auto rowwise_scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size; auto rowwise_scale_size = rowwise_scale_meta.bytes();
auto columnwise_scale_size = product(colwise_scale_meta.shape) * colwise_scale_meta.type_size; auto columnwise_scale_size = colwise_scale_meta.bytes();
auto scale_shape = rowwise_scale_meta.shape; auto scale_shape = rowwise_scale_meta.shape;
auto columnwise_scale_shape = colwise_scale_meta.shape; auto columnwise_scale_shape = colwise_scale_meta.shape;
if (rowwise) { if (rowwise) {
...@@ -331,7 +338,7 @@ Tensor::Tensor(const std::string& name, ...@@ -331,7 +338,7 @@ Tensor::Tensor(const std::string& name,
void Tensor::to_cpu() const { void Tensor::to_cpu() const {
const NVTEShape s = tensor_.shape(); const NVTEShape s = tensor_.shape();
const size_t size = product(s) * typeToSize(tensor_.dtype()); const size_t size = bytes(s, tensor_.dtype());
if (rowwise_) { if (rowwise_) {
cudaMemcpy(cpu_data_rowwise_.get(), cudaMemcpy(cpu_data_rowwise_.get(),
tensor_.get_rowwise_data().data_ptr, tensor_.get_rowwise_data().data_ptr,
...@@ -360,14 +367,14 @@ void Tensor::to_cpu() const { ...@@ -360,14 +367,14 @@ void Tensor::to_cpu() const {
auto [rowwise_scale_meta, colwise_scale_meta] = auto [rowwise_scale_meta, colwise_scale_meta] =
get_scales(s, tensor_.scaling_mode()); get_scales(s, tensor_.scaling_mode());
if (rowwise_) { if (rowwise_) {
auto scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size; auto scale_size = rowwise_scale_meta.bytes();
cudaMemcpy(rowwise_scale_inv_cpu_data_.get(), cudaMemcpy(rowwise_scale_inv_cpu_data_.get(),
tensor_.get_rowwise_scale_inv().data_ptr, tensor_.get_rowwise_scale_inv().data_ptr,
scale_size, scale_size,
cudaMemcpyDeviceToHost); cudaMemcpyDeviceToHost);
} }
if (columnwise_) { if (columnwise_) {
auto scale_size = product(colwise_scale_meta.shape) * colwise_scale_meta.type_size; auto scale_size = colwise_scale_meta.bytes();
cudaMemcpy(columnwise_scale_inv_cpu_data_.get(), cudaMemcpy(columnwise_scale_inv_cpu_data_.get(),
tensor_.get_columnwise_scale_inv().data_ptr, tensor_.get_columnwise_scale_inv().data_ptr,
scale_size, scale_size,
...@@ -378,34 +385,32 @@ void Tensor::to_cpu() const { ...@@ -378,34 +385,32 @@ void Tensor::to_cpu() const {
void Tensor::from_cpu() const { void Tensor::from_cpu() const {
const NVTEShape s = tensor_.shape(); const NVTEShape s = tensor_.shape();
const size_t size = product(s) * typeToSize(tensor_.dtype()); const size_t size = bytes(s, tensor_.dtype());
if (rowwise_) { if (rowwise_) {
cudaMemcpy(tensor_.get_rowwise_data().data_ptr, cudaMemcpy(tensor_.get_rowwise_data().data_ptr, cpu_data_rowwise_.get(), size,
cpu_data_rowwise_.get(), size, cudaMemcpyHostToDevice); cudaMemcpyHostToDevice);
} }
if (columnwise_) { if (columnwise_) {
cudaMemcpy(tensor_.get_columnwise_data().data_ptr, cudaMemcpy(tensor_.get_columnwise_data().data_ptr, cpu_data_columnwise_.get(), size,
cpu_data_columnwise_.get(), size, cudaMemcpyHostToDevice); cudaMemcpyHostToDevice);
} }
if (isFp8Type(dtype())) { if (isFp8Type(dtype())) {
if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) { if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
if (tensor_.amax() != nullptr){ if (tensor_.amax() != nullptr){
cudaMemcpy(tensor_.amax(), amax_cpu_data_.get(), sizeof(float), cudaMemcpy(tensor_.amax(), amax_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice);
cudaMemcpyHostToDevice);
} }
cudaMemcpy(tensor_.scale(), scale_cpu_data_.get(), sizeof(float), cudaMemcpy(tensor_.scale(), scale_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice);
cudaMemcpyHostToDevice);
} }
auto [rowwise_scale_meta, colwise_scale_meta] = auto [rowwise_scale_meta, colwise_scale_meta] =
get_scales(s, tensor_.scaling_mode()); get_scales(s, tensor_.scaling_mode());
if (rowwise_) { if (rowwise_) {
auto scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size; auto scale_size = rowwise_scale_meta.bytes();
cudaMemcpy(tensor_.get_rowwise_scale_inv().data_ptr, cudaMemcpy(tensor_.get_rowwise_scale_inv().data_ptr,
rowwise_scale_inv_cpu_data_.get(), scale_size, rowwise_scale_inv_cpu_data_.get(), scale_size,
cudaMemcpyHostToDevice); cudaMemcpyHostToDevice);
} }
if (columnwise_) { if (columnwise_) {
auto scale_size = product(colwise_scale_meta.shape) * colwise_scale_meta.type_size; auto scale_size = colwise_scale_meta.bytes();
cudaMemcpy(tensor_.get_columnwise_scale_inv().data_ptr, cudaMemcpy(tensor_.get_columnwise_scale_inv().data_ptr,
columnwise_scale_inv_cpu_data_.get(), scale_size, columnwise_scale_inv_cpu_data_.get(), scale_size,
cudaMemcpyHostToDevice); cudaMemcpyHostToDevice);
......
...@@ -10,10 +10,15 @@ ...@@ -10,10 +10,15 @@
#include <vector> #include <vector>
#include <array> #include <array>
#include <random> #include <random>
#include <cudaTypedefs.h>
#define FP4_TYPE_SUPPORTED (CUDA_VERSION >= 12080)
#include <cuda_bf16.h> #include <cuda_bf16.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <cuda_fp8.h> #include <cuda_fp8.h>
#if FP4_TYPE_SUPPORTED
#include <cuda_fp4.h>
#endif
#include <cuda_runtime_api.h> #include <cuda_runtime_api.h>
#include <transformer_engine/transformer_engine.h> #include <transformer_engine/transformer_engine.h>
...@@ -55,19 +60,32 @@ using bf16 = nv_bfloat16; ...@@ -55,19 +60,32 @@ using bf16 = nv_bfloat16;
using fp8e4m3 = __nv_fp8_e4m3; using fp8e4m3 = __nv_fp8_e4m3;
using fp8e5m2 = __nv_fp8_e5m2; using fp8e5m2 = __nv_fp8_e5m2;
using fp8e8m0 = uint8_t; using fp8e8m0 = uint8_t;
#if FP4_TYPE_SUPPORTED
using fp4e2m1 = __nv_fp4_e2m1;
#endif
template <typename T> template <typename T>
struct TypeInfo{ struct BitsNumber;
using types = std::tuple<byte,
int16, #if FP4_TYPE_SUPPORTED
int32, template <>
int64, struct BitsNumber<fp4e2m1> {
fp32, static constexpr size_t num_bits = 4;
fp16, };
bf16, #endif
fp8e4m3,
fp8e5m2, template <typename T>
fp8e8m0>; struct BitsNumber {
static constexpr size_t num_bits = 8 * sizeof(T);
};
template <typename T>
struct TypeInfo {
#if FP4_TYPE_SUPPORTED
using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, fp8e8m0, fp4e2m1>;
#else
using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, fp8e8m0>;
#endif
template <typename U, DType current> template <typename U, DType current>
struct Helper { struct Helper {
...@@ -94,7 +112,7 @@ struct TypeInfo{ ...@@ -94,7 +112,7 @@ struct TypeInfo{
} }
constexpr static DType dtype = getType<T>(); constexpr static DType dtype = getType<T>();
constexpr static size_t size = sizeof(T); constexpr static size_t size = BitsNumber<T>::num_bits;
}; };
class Tensor { class Tensor {
...@@ -416,9 +434,10 @@ inline float dsilu(const float x) { return x * dsigmoid(x) + sigmoid(x); } ...@@ -416,9 +434,10 @@ inline float dsilu(const float x) { return x * dsigmoid(x) + sigmoid(x); }
inline float srelu(const float x) { return x > 0 ? x * x : 0; } inline float srelu(const float x) { return x > 0 ? x * x : 0; }
inline float dsrelu(const float x) { return fmaxf(0, 2 * x); } inline float dsrelu(const float x) { return fmaxf(0, 2 * x); }
size_t typeToSize(DType type); size_t typeToNumBits(DType type);
size_t product(const NVTEShape &shape); size_t product(const NVTEShape &shape);
size_t product(const std::vector<size_t> &shape); size_t product(const std::vector<size_t> &shape);
size_t bytes(const NVTEShape& shape, const DType type);
size_t first_dimension(const std::vector<size_t> &shape); size_t first_dimension(const std::vector<size_t> &shape);
size_t last_dimension(const std::vector<size_t> &shape); size_t last_dimension(const std::vector<size_t> &shape);
...@@ -464,6 +483,16 @@ constexpr int32_t blackwellComputeCapability = 100; ...@@ -464,6 +483,16 @@ constexpr int32_t blackwellComputeCapability = 100;
} // namespace test } // namespace test
#if FP4_TYPE_SUPPORTED
#define SWITCH_FP4_TYPE_HANDLE(type, ...) \
case DType::kFloat4E2M1: { \
using type = fp4e2m1; \
{ __VA_ARGS__ } \
} break;
#else
#define SWITCH_FP4_TYPE_HANDLE(type, ...) // do nothing
#endif
#define TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(dtype, type, ...) \ #define TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(dtype, type, ...) \
switch (dtype) { \ switch (dtype) { \
using namespace transformer_engine; \ using namespace transformer_engine; \
...@@ -515,8 +544,16 @@ constexpr int32_t blackwellComputeCapability = 100; ...@@ -515,8 +544,16 @@ constexpr int32_t blackwellComputeCapability = 100;
{__VA_ARGS__} \ {__VA_ARGS__} \
} \ } \
break; \ break; \
case DType::kFloat8E8M0: \
{ \
using type = fp8e8m0; \
{__VA_ARGS__} \
} \
break; \
SWITCH_FP4_TYPE_HANDLE(type, __VA_ARGS__) \
default: \ default: \
NVTE_ERROR("Invalid type."); \ NVTE_ERROR("Invalid type."); \
} }
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(dtype, type, ...) \ #define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(dtype, type, ...) \
...@@ -535,7 +572,15 @@ constexpr int32_t blackwellComputeCapability = 100; ...@@ -535,7 +572,15 @@ constexpr int32_t blackwellComputeCapability = 100;
} \ } \
break; \ break; \
default: \ default: \
NVTE_ERROR("Invalid type."); \ NVTE_ERROR("Invalid type."); \
}
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP4_ONLY(dtype, type, ...) \
switch (dtype) { \
using namespace transformer_engine; \
SWITCH_FP4_TYPE_HANDLE(type, __VA_ARGS__) \ SWITCH_FP4_TYPE_HANDLE(type, __VA_ARGS__) \
default: \
NVTE_ERROR("Invalid type."); \
} }
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(dtype, type, ...) \ #define TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(dtype, type, ...) \
...@@ -560,5 +605,5 @@ constexpr int32_t blackwellComputeCapability = 100; ...@@ -560,5 +605,5 @@ constexpr int32_t blackwellComputeCapability = 100;
} \ } \
break; \ break; \
default: \ default: \
NVTE_ERROR("Invalid type."); \ NVTE_ERROR("Invalid type."); \
} }
...@@ -196,7 +196,7 @@ TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, siz ...@@ -196,7 +196,7 @@ TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, siz
if (param_type == NVTETensorParam::kNVTERowwiseData || if (param_type == NVTETensorParam::kNVTERowwiseData ||
param_type == NVTETensorParam::kNVTEColumnwiseData) { param_type == NVTETensorParam::kNVTEColumnwiseData) {
// Offset data pointer // Offset data pointer
param_dptr += chunk_offset * typeToSize(param_dtype); param_dptr += get_buffer_size_bytes(chunk_offset, param_dtype);
param_shape = chunk_shape; param_shape = chunk_shape;
if (param_type == NVTETensorParam::kNVTEColumnwiseData && if (param_type == NVTETensorParam::kNVTEColumnwiseData &&
...@@ -217,7 +217,7 @@ TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, siz ...@@ -217,7 +217,7 @@ TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, siz
} else { } else {
chunk_scale_height /= 32; chunk_scale_height /= 32;
} }
param_dptr += (chunk_offset / 32) * typeToSize(param_dtype); param_dptr += get_buffer_size_bytes(chunk_offset / 32, param_dtype);
param_shape = {chunk_scale_height, chunk_scale_width}; param_shape = {chunk_scale_height, chunk_scale_width};
} }
...@@ -236,7 +236,7 @@ TensorWrapper CommOverlapCore::get_buffer_chunk_like(const TensorWrapper &source ...@@ -236,7 +236,7 @@ TensorWrapper CommOverlapCore::get_buffer_chunk_like(const TensorWrapper &source
auto chunk = get_tensor_chunk(source, chunk_offset, chunk_shape); auto chunk = get_tensor_chunk(source, chunk_offset, chunk_shape);
// Update chunk with offset data pointers from the communication buffer // Update chunk with offset data pointers from the communication buffer
auto ubuf_ptr = reinterpret_cast<char *>(_ubuf.dptr()) + (chunk_offset * _ubuf.element_size()); auto ubuf_ptr = reinterpret_cast<char *>(_ubuf.dptr()) + chunk_offset * _ubuf.element_size();
if (chunk.dptr() != nullptr) { if (chunk.dptr() != nullptr) {
chunk.set_rowwise_data(reinterpret_cast<void *>(ubuf_ptr), chunk.dtype(), chunk.shape()); chunk.set_rowwise_data(reinterpret_cast<void *>(ubuf_ptr), chunk.dtype(), chunk.shape());
} }
...@@ -269,7 +269,7 @@ CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType ...@@ -269,7 +269,7 @@ CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType
"or 2 (multi-atomic)."); "or 2 (multi-atomic).");
NVTE_CHECK(buffer_shape.size() == 2, "Userbuffer shape must be 2-dimensional!"); NVTE_CHECK(buffer_shape.size() == 2, "Userbuffer shape must be 2-dimensional!");
size_t buffer_bytes = buffer_shape[0] * buffer_shape[1] * typeToSize(buffer_dtype); size_t buffer_bytes = get_buffer_size_bytes(buffer_shape[0], buffer_shape[1], buffer_dtype);
void *buffer_ptr; void *buffer_ptr;
_ub_reg = register_user_buffer_collective(&buffer_ptr, buffer_bytes, _ub_comm, true); _ub_reg = register_user_buffer_collective(&buffer_ptr, buffer_bytes, _ub_comm, true);
if (_ub_comm->myrank == 0) printf("!!! [UB] Register UBuf %d\n", _ub_reg); if (_ub_comm->myrank == 0) printf("!!! [UB] Register UBuf %d\n", _ub_reg);
...@@ -306,7 +306,7 @@ void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const Te ...@@ -306,7 +306,7 @@ void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const Te
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[0], _start_comm, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[0], _start_comm, 0));
// Communication: AG and RS // Communication: AG and RS
int comm_elements = (_ubuf.numel() / 2) * _ubuf.element_size(); // UBUF uses 2Byte element size int comm_elements = _ubuf.bytes() / 2; // UBUF uses 2Byte element size
if (comm_type == CommOverlapType::AG) { if (comm_type == CommOverlapType::AG) {
allgather2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm, allgather2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm,
(cudaEvent_t)_comm_launch_event); (cudaEvent_t)_comm_launch_event);
...@@ -606,7 +606,7 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape, ...@@ -606,7 +606,7 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape,
// Create workspace tensor with userbuffer // Create workspace tensor with userbuffer
NVTE_CHECK(buffer_shape.size() == 2, "Userbuffer shape must be 2-dimensional!"); NVTE_CHECK(buffer_shape.size() == 2, "Userbuffer shape must be 2-dimensional!");
size_t buffer_bytes = buffer_shape[0] * buffer_shape[1] * typeToSize(buffer_dtype); size_t buffer_bytes = get_buffer_size_bytes(buffer_shape[0], buffer_shape[1], buffer_dtype);
int buffer_chunk_bytes = buffer_bytes / tp_size; int buffer_chunk_bytes = buffer_bytes / tp_size;
_num_ubuf_chunks = tp_size; _num_ubuf_chunks = tp_size;
if (_is_reduce_scatter) { if (_is_reduce_scatter) {
...@@ -704,7 +704,7 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag( ...@@ -704,7 +704,7 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag(
assert(pre_gelu_out.numel() == 0); assert(pre_gelu_out.numel() == 0);
// Get communication and GEMM output chunk sizes // Get communication and GEMM output chunk sizes
const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); const int comm_bytes = _ubufs[0].bytes();
// Create an GEMM output buffer with N+1 chunks in a contiguous memory // Create an GEMM output buffer with N+1 chunks in a contiguous memory
void *D_buffer_ptr; void *D_buffer_ptr;
...@@ -762,21 +762,20 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag( ...@@ -762,21 +762,20 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag(
if (B_copy.numel() > 0) { if (B_copy.numel() > 0) {
assert(B_copy.numel() == _ubufs[_self_chunk_id].numel()); assert(B_copy.numel() == _ubufs[_self_chunk_id].numel());
assert(B_copy.element_size() == _ubufs[_self_chunk_id].element_size()); assert(B_copy.element_size() == _ubufs[_self_chunk_id].element_size());
NVTE_CHECK_CUDA( NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_self_chunk_id].dptr(),
cudaMemcpyAsync(B_copy.dptr(), _ubufs[_self_chunk_id].dptr(), _ubufs[_self_chunk_id].bytes(), cudaMemcpyDeviceToDevice,
_ubufs[_self_chunk_id].numel() * _ubufs[_self_chunk_id].element_size(), _stream_send[0]));
cudaMemcpyDeviceToDevice, _stream_send[0]));
NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send[0])); NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send[0]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0));
} }
// Copy the first GEMM output chunk to the end chunk position of D_buffer // Copy the first GEMM output chunk to the end chunk position of D_buffer
char *src_ptr = reinterpret_cast<char *>(D_buffer.dptr()); char *src_ptr = reinterpret_cast<char *>(D_buffer.dptr());
NVTE_CHECK_CUDA(cudaMemcpyAsync(src_ptr + (D.numel() * D.element_size()), src_ptr, D_chunk_bytes, NVTE_CHECK_CUDA(cudaMemcpyAsync(src_ptr + D.bytes(), src_ptr, D_chunk_bytes,
cudaMemcpyDeviceToDevice, stream_main)); cudaMemcpyDeviceToDevice, stream_main));
// Return the last N rows of D_buffer // Return the last N rows of D_buffer
NVTE_CHECK_CUDA(cudaMemcpyAsync(D.dptr(), src_ptr + D_chunk_bytes, D.numel() * D.element_size(), NVTE_CHECK_CUDA(cudaMemcpyAsync(D.dptr(), src_ptr + D_chunk_bytes, D.bytes(),
cudaMemcpyDeviceToDevice, stream_main)); cudaMemcpyDeviceToDevice, stream_main));
// Clean up buffer allocation // Clean up buffer allocation
...@@ -806,7 +805,7 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, ...@@ -806,7 +805,7 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
const size_t n_chunk = _ubufs[0].size(0); const size_t n_chunk = _ubufs[0].size(0);
// Get communication and GEMM output chunk sizes // Get communication and GEMM output chunk sizes
const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); const int comm_bytes = _ubufs[0].bytes();
const bool do_gelu = pre_gelu_out.numel() > 0; const bool do_gelu = pre_gelu_out.numel() > 0;
size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
...@@ -882,8 +881,8 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, ...@@ -882,8 +881,8 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
assert(B_copy.numel() == _ubufs[_tp_id].numel()); assert(B_copy.numel() == _ubufs[_tp_id].numel());
assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); assert(B_copy.element_size() == _ubufs[_tp_id].element_size());
NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(), NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(),
_ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(), _ubufs[_tp_id].bytes(), cudaMemcpyDeviceToDevice,
cudaMemcpyDeviceToDevice, _stream_send[0])); _stream_send[0]));
} }
} }
} else { } else {
...@@ -935,8 +934,8 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, ...@@ -935,8 +934,8 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
assert(B_copy.numel() == _ubufs[_tp_id].numel()); assert(B_copy.numel() == _ubufs[_tp_id].numel());
assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); assert(B_copy.element_size() == _ubufs[_tp_id].element_size());
NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(), NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(),
_ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(), _ubufs[_tp_id].bytes(), cudaMemcpyDeviceToDevice,
cudaMemcpyDeviceToDevice, _stream_send[0])); _stream_send[0]));
} }
} }
} }
...@@ -966,7 +965,7 @@ void CommOverlapP2PBase::atomic_gemm_overlap_rs( ...@@ -966,7 +965,7 @@ void CommOverlapP2PBase::atomic_gemm_overlap_rs(
_ub_comm->cga_size = _cga_size; _ub_comm->cga_size = _cga_size;
// Get communication and GEMM input chunk sizes // Get communication and GEMM input chunk sizes
const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); const int comm_bytes = _ubufs[0].bytes();
// Reset counters // Reset counters
int *counter_ptr = reinterpret_cast<int *>(_counter.dptr()); int *counter_ptr = reinterpret_cast<int *>(_counter.dptr());
...@@ -1033,7 +1032,7 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa, ...@@ -1033,7 +1032,7 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa,
size_t m = transa ? A.size(0) : A.size(1); size_t m = transa ? A.size(0) : A.size(1);
size_t k = transa ? A.size(1) : A.size(0); size_t k = transa ? A.size(1) : A.size(0);
size_t n_chunk = _ubufs[0].size(0); size_t n_chunk = _ubufs[0].size(0);
const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); const int comm_bytes = _ubufs[0].bytes();
// Get input and workspace data pointers // Get input and workspace data pointers
size_t input_chunk_size = n_chunk * k; size_t input_chunk_size = n_chunk * k;
......
...@@ -116,13 +116,20 @@ void checkCuDriverContext(CUstream stream) { ...@@ -116,13 +116,20 @@ void checkCuDriverContext(CUstream stream) {
} }
CUtensorMapDataType get_CUtensorMapDataType(DType dtype) { CUtensorMapDataType get_CUtensorMapDataType(DType dtype) {
static const std::unordered_map<DType, CUtensorMapDataType> dtypeMapping = { static const std::unordered_map<DType, CUtensorMapDataType> dtypeMapping = []() {
{DType::kByte, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8}, std::unordered_map<DType, CUtensorMapDataType> typeMapping = {
{DType::kFloat32, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT32}, {DType::kByte, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8},
{DType::kFloat16, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16}, {DType::kFloat32, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT32},
{DType::kBFloat16, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16}, {DType::kFloat16, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16},
{DType::kFloat8E4M3, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8}, {DType::kBFloat16, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16},
{DType::kFloat8E5M2, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8}}; {DType::kFloat8E4M3, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8},
{DType::kFloat8E5M2, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8}};
#if FP4_TYPE_SUPPORTED
typeMapping.insert(
{DType::kFloat4E2M1, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN8B});
#endif
return typeMapping;
}();
return dtypeMapping.at(dtype); return dtypeMapping.at(dtype);
} }
...@@ -130,7 +137,7 @@ CUtensorMapDataType get_CUtensorMapDataType(DType dtype) { ...@@ -130,7 +137,7 @@ CUtensorMapDataType get_CUtensorMapDataType(DType dtype) {
void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY, const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY,
const uint32_t shmemX, const uint32_t stride_elems, const uint32_t shmemX, const uint32_t stride_elems,
const uint32_t offset_elems, const size_t type_size) { const uint32_t offset_elems, const size_t type_num_bits) {
// Get a function pointer to the cuTensorMapEncodeTiled driver API // Get a function pointer to the cuTensorMapEncodeTiled driver API
// Note: PFN_cuTensorMapEncodeTiled is not defined in cuda13 // Note: PFN_cuTensorMapEncodeTiled is not defined in cuda13
static PFN_cuTensorMapEncodeTiled_v12000 cuDriverTensorMapEncodeTiled = []() { static PFN_cuTensorMapEncodeTiled_v12000 cuDriverTensorMapEncodeTiled = []() {
...@@ -142,7 +149,7 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, ...@@ -142,7 +149,7 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
uint64_t size[rank] = {globalX, globalY}; uint64_t size[rank] = {globalX, globalY};
// The stride is the number of bytes to traverse from the first element of one row to the next // The stride is the number of bytes to traverse from the first element of one row to the next
uint64_t stride[rank - 1] = {stride_elems * type_size}; uint64_t stride[rank - 1] = {(stride_elems * type_num_bits) / 8};
// The boxSize is the size of the shared memory buffer that is used as the // The boxSize is the size of the shared memory buffer that is used as the
// source/destination of a TMA transfer // source/destination of a TMA transfer
...@@ -152,15 +159,15 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, ...@@ -152,15 +159,15 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
uint32_t elemStride[rank] = {1, 1}; uint32_t elemStride[rank] = {1, 1};
const CUtensorMapDataType tensorDataType = get_CUtensorMapDataType(tensor.dtype); const CUtensorMapDataType tensorDataType = get_CUtensorMapDataType(tensor.dtype);
void *dataPtr = void *dataPtr = reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(tensor.dptr) +
reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(tensor.dptr) + offset_elems * type_size); (offset_elems * type_num_bits) / 8);
NVTE_CHECK(is_aligned_ptr(dataPtr, TMA_gmem_alignment), NVTE_CHECK(is_aligned_ptr(dataPtr, TMA_gmem_alignment),
"Tensor data pointer must be 16B aligned"); "Tensor data pointer must be 16B aligned");
const int TMA_needed_size = TMA_gmem_alignment / type_size; const int TMA_needed_size = (TMA_gmem_alignment * 8) / type_num_bits;
NVTE_CHECK(globalX % TMA_needed_size == 0, "Shape not supported. For ", type_size, NVTE_CHECK(globalX % TMA_needed_size == 0, "Shape not supported. For ", type_num_bits,
"-byte data type, expected multiple of ", TMA_needed_size, ", got ", globalX); "-bit data type, expected multiple of ", TMA_needed_size, ", got ", globalX);
// Create the tensor descriptor. // Create the tensor descriptor.
NVTE_CHECK_CUDA_DRIVER(cuDriverTensorMapEncodeTiled( NVTE_CHECK_CUDA_DRIVER(cuDriverTensorMapEncodeTiled(
...@@ -206,4 +213,18 @@ std::vector<std::vector<Tensor *>> convert_tensor_array(NVTETensor **nvte_tensor ...@@ -206,4 +213,18 @@ std::vector<std::vector<Tensor *>> convert_tensor_array(NVTETensor **nvte_tensor
return ret; return ret;
} }
size_t get_buffer_size_bytes(const size_t elements_num, const DType buffer_dtype) {
return (elements_num * typeToNumBits(buffer_dtype)) / 8;
}
size_t get_buffer_size_bytes(const size_t dim_first, const size_t dim_last,
const DType buffer_dtype) {
if (buffer_dtype == DType::kFloat4E2M1) {
NVTE_CHECK(dim_last % 2 == 0,
"Last dimension of a tensor with FP4 type of data must be an even number!");
}
const size_t elements_num = dim_first * dim_last;
return get_buffer_size_bytes(elements_num, buffer_dtype);
}
} // namespace transformer_engine } // namespace transformer_engine
...@@ -8,9 +8,15 @@ ...@@ -8,9 +8,15 @@
#define TRANSFORMER_ENGINE_COMMON_COMMON_H_ #define TRANSFORMER_ENGINE_COMMON_COMMON_H_
#include <cudaTypedefs.h> #include <cudaTypedefs.h>
#define FP4_TYPE_SUPPORTED (CUDA_VERSION >= 12080)
#include <cuda_bf16.h> #include <cuda_bf16.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <cuda_fp8.h> #include <cuda_fp8.h>
#if FP4_TYPE_SUPPORTED
#include <cuda_fp4.h>
#endif
#include <cuda_runtime_api.h> #include <cuda_runtime_api.h>
#include <transformer_engine/transformer_engine.h> #include <transformer_engine/transformer_engine.h>
...@@ -183,6 +189,7 @@ struct Tensor { ...@@ -183,6 +189,7 @@ struct Tensor {
} }
break; break;
case NVTE_MXFP8_1D_SCALING: case NVTE_MXFP8_1D_SCALING:
case NVTE_FWD_NVFP4_BWD_MXFP8_SCALING:
if (!has_data() && has_columnwise_data()) { if (!has_data() && has_columnwise_data()) {
return columnwise_data.shape; return columnwise_data.shape;
} else { } else {
...@@ -268,6 +275,13 @@ constexpr T DIVUP(const T &x, const T &y) { ...@@ -268,6 +275,13 @@ constexpr T DIVUP(const T &x, const T &y) {
return (((x) + ((y)-1)) / (y)); return (((x) + ((y)-1)) / (y));
} }
template <typename T1, typename T2>
constexpr __device__ __host__ __forceinline__ uint64_t DIVUP_TO_MULTIPLE(const T1 &N, const T2 &M) {
static_assert(std::is_integral<T1>::value && std::is_integral<T2>::value,
"Integral type required.");
return DIVUP(static_cast<uint64_t>(N), static_cast<uint64_t>(M)) * M;
}
using byte = uint8_t; using byte = uint8_t;
using int16 = int16_t; using int16 = int16_t;
using int32 = int32_t; using int32 = int32_t;
...@@ -280,6 +294,9 @@ using fp8e5m2 = __nv_fp8_e5m2; ...@@ -280,6 +294,9 @@ using fp8e5m2 = __nv_fp8_e5m2;
#if CUDA_VERSION >= 12080 #if CUDA_VERSION >= 12080
using fp8e8m0 = __nv_fp8_e8m0; using fp8e8m0 = __nv_fp8_e8m0;
#endif #endif
#if FP4_TYPE_SUPPORTED
using fp4e2m1 = __nv_fp4_e2m1;
#endif
using e8m0_t = uint8_t; using e8m0_t = uint8_t;
namespace detail { namespace detail {
...@@ -303,11 +320,21 @@ TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e5m2) ...@@ -303,11 +320,21 @@ TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e5m2)
#if CUDA_VERSION >= 12080 #if CUDA_VERSION >= 12080
TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e8m0) TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e8m0)
#endif #endif
#if FP4_TYPE_SUPPORTED
TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp4_e2m1)
#endif
#undef TRANSFORMER_ENGINE_TYPE_NAME #undef TRANSFORMER_ENGINE_TYPE_NAME
template <typename T> template <typename T>
struct TypeExtrema; struct TypeExtrema;
#if FP4_TYPE_SUPPORTED
template <>
struct TypeExtrema<fp4e2m1> {
static constexpr float max = 6.0f;
};
#endif
template <> template <>
struct TypeExtrema<fp8e4m3> { struct TypeExtrema<fp8e4m3> {
static constexpr float max = 448.0f; static constexpr float max = 448.0f;
...@@ -337,9 +364,28 @@ struct TypeExtrema { ...@@ -337,9 +364,28 @@ struct TypeExtrema {
} // namespace detail } // namespace detail
template <typename T>
struct BitsNumber;
#if FP4_TYPE_SUPPORTED
template <>
struct BitsNumber<fp4e2m1> {
static constexpr size_t num_bits = 4;
};
#endif
template <typename T>
struct BitsNumber {
static constexpr size_t num_bits = 8 * sizeof(T);
};
template <typename T> template <typename T>
struct TypeInfo { struct TypeInfo {
#if FP4_TYPE_SUPPORTED
using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, fp4e2m1>;
#else
using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2>; using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2>;
#endif
template <typename U, DType current> template <typename U, DType current>
struct Helper { struct Helper {
...@@ -364,11 +410,21 @@ struct TypeInfo { ...@@ -364,11 +410,21 @@ struct TypeInfo {
} }
constexpr static DType dtype = getType<T>(); constexpr static DType dtype = getType<T>();
constexpr static size_t size = sizeof(T); constexpr static size_t size = BitsNumber<T>::num_bits;
constexpr static float max_finite_value = detail::TypeExtrema<T>::max; constexpr static float max_finite_value = detail::TypeExtrema<T>::max;
constexpr static const char *name = detail::type_name<T>(); constexpr static const char *name = detail::type_name<T>();
}; };
#if FP4_TYPE_SUPPORTED
#define SWITCH_FP4_TYPE_HANDLE(type, ...) \
case DType::kFloat4E2M1: { \
using type = fp4e2m1; \
{ __VA_ARGS__ } \
} break;
#else
#define SWITCH_FP4_TYPE_HANDLE(type, ...) // do nothing
#endif
#define TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(dtype, type, ...) \ #define TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(dtype, type, ...) \
switch (dtype) { \ switch (dtype) { \
using namespace transformer_engine; \ using namespace transformer_engine; \
...@@ -412,6 +468,7 @@ struct TypeInfo { ...@@ -412,6 +468,7 @@ struct TypeInfo {
using type = byte; \ using type = byte; \
{ __VA_ARGS__ } \ { __VA_ARGS__ } \
} break; \ } break; \
SWITCH_FP4_TYPE_HANDLE(type, __VA_ARGS__) \
default: \ default: \
NVTE_ERROR("Invalid type."); \ NVTE_ERROR("Invalid type."); \
} }
...@@ -523,6 +580,9 @@ struct TypeInfo { ...@@ -523,6 +580,9 @@ struct TypeInfo {
case DType::kFloat8E4M3: { \ case DType::kFloat8E4M3: { \
NVTE_ERROR("FP8 type not instantiated for input."); \ NVTE_ERROR("FP8 type not instantiated for input."); \
} break; \ } break; \
case DType::kFloat4E2M1: { \
NVTE_ERROR("FP4 type not instantiated for input."); \
} break; \
default: \ default: \
NVTE_ERROR("Invalid type."); \ NVTE_ERROR("Invalid type."); \
} }
...@@ -593,6 +653,14 @@ struct is_fp8<fp8e4m3> : std::true_type {}; ...@@ -593,6 +653,14 @@ struct is_fp8<fp8e4m3> : std::true_type {};
template <> template <>
struct is_fp8<fp8e5m2> : std::true_type {}; struct is_fp8<fp8e5m2> : std::true_type {};
template <typename T>
struct is_fp4 : std::false_type {};
#if FP4_TYPE_SUPPORTED
template <>
struct is_fp4<fp4e2m1> : std::true_type {};
#endif
// [128,4] rowwise and [4,128] colwise alignment requirements for the tensor with scaling factors // [128,4] rowwise and [4,128] colwise alignment requirements for the tensor with scaling factors
constexpr size_t scale_tensor_alignment_X_rowwise = 4; constexpr size_t scale_tensor_alignment_X_rowwise = 4;
constexpr size_t scale_tensor_alignment_Y_rowwise = 128; constexpr size_t scale_tensor_alignment_Y_rowwise = 128;
...@@ -611,13 +679,16 @@ inline bool is_aligned_tensor_data(const Tensor &t, size_t alignment) { ...@@ -611,13 +679,16 @@ inline bool is_aligned_tensor_data(const Tensor &t, size_t alignment) {
} }
size_t typeToSize(const DType type); size_t typeToSize(const DType type);
size_t typeToNumBits(const DType type);
size_t get_buffer_size_bytes(const size_t N, const DType buffer_dtype);
size_t get_buffer_size_bytes(const size_t dim_first, const size_t dim_last,
const DType buffer_dtype);
void CheckNoopTensor(const Tensor &t, const std::string &name); void CheckNoopTensor(const Tensor &t, const std::string &name);
void CheckInputTensor(const Tensor &t, const std::string &name); void CheckInputTensor(const Tensor &t, const std::string &name);
void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty = false); void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty = false);
bool is_fp8_dtype(const DType t);
/*! \brief Update a tensor's FP8 scale-inverse /*! \brief Update a tensor's FP8 scale-inverse
* *
* The FP8 scale-inverse (dequantization scaling factor) is updated * The FP8 scale-inverse (dequantization scaling factor) is updated
...@@ -636,7 +707,7 @@ CUtensorMapDataType get_CUtensorMapDataType(DType dtype); ...@@ -636,7 +707,7 @@ CUtensorMapDataType get_CUtensorMapDataType(DType dtype);
void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY, const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY,
const uint32_t shmemX, const uint32_t stride_elems, const uint32_t shmemX, const uint32_t stride_elems,
const uint32_t offset_elems, const size_t type_size); const uint32_t offset_elems, const size_t type_num_bits);
bool is_supported_by_CC_100(); bool is_supported_by_CC_100();
......
...@@ -325,7 +325,7 @@ void thd_read_half_tensor(const Tensor &tensor, const Tensor &cu_seqlens, Tensor ...@@ -325,7 +325,7 @@ void thd_read_half_tensor(const Tensor &tensor, const Tensor &cu_seqlens, Tensor
int batch = cu_seqlens_shape[0] - 1; int batch = cu_seqlens_shape[0] - 1;
int num_heads = tensor_shape[seq_dim + 1]; int num_heads = tensor_shape[seq_dim + 1];
int dim_per_head = tensor_shape[seq_dim + 2]; int dim_per_head = tensor_shape[seq_dim + 2];
int hidden_size_in_bytes = num_heads * dim_per_head * typeToSize(tensor.dtype()); int hidden_size_in_bytes = (num_heads * dim_per_head * typeToNumBits(tensor.dtype())) / 8;
// For 128-bits load/store // For 128-bits load/store
NVTE_CHECK(hidden_size_in_bytes % 16 == 0); NVTE_CHECK(hidden_size_in_bytes % 16 == 0);
...@@ -582,7 +582,7 @@ static void thd_grad_correction_helper(Tensor grad, const Tensor &grad_per_step, ...@@ -582,7 +582,7 @@ static void thd_grad_correction_helper(Tensor grad, const Tensor &grad_per_step,
NVTE_CHECK(grad_per_step_shape[seq_dim + 2] == dim_per_head); NVTE_CHECK(grad_per_step_shape[seq_dim + 2] == dim_per_head);
size_t hidden_size = num_heads * dim_per_head; size_t hidden_size = num_heads * dim_per_head;
NVTE_CHECK((hidden_size * typeToSize(grad.dtype())) % 16 == 0); NVTE_CHECK(((hidden_size * typeToNumBits(grad.dtype())) / 8) % 16 == 0);
constexpr unsigned int block = 256; constexpr unsigned int block = 256;
unsigned int grid_x; unsigned int grid_x;
......
...@@ -377,7 +377,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( ...@@ -377,7 +377,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
const size_t num_bytes_per_seqlen = alignTo<16>(b * sizeof(int32_t)); const size_t num_bytes_per_seqlen = alignTo<16>(b * sizeof(int32_t));
const size_t actual_seqlen_workspace_size = is_padding ? 2 * num_bytes_per_seqlen : 0; const size_t actual_seqlen_workspace_size = is_padding ? 2 * num_bytes_per_seqlen : 0;
const size_t num_bytes_per_ragged_offset = const size_t num_bytes_per_ragged_offset =
alignTo<16>((b + 1) * typeToSize(ragged_offset_type)); alignTo<16>(((b + 1) * typeToNumBits(ragged_offset_type)) / 8);
size_t seqlen_offsets_workspace_size = 0; size_t seqlen_offsets_workspace_size = 0;
if (is_ragged_q || is_ragged_kv) { if (is_ragged_q || is_ragged_kv) {
size_t count = 2 * (static_cast<size_t>(is_ragged_q) + static_cast<size_t>(is_ragged_kv)); size_t count = 2 * (static_cast<size_t>(is_ragged_q) + static_cast<size_t>(is_ragged_kv));
...@@ -831,7 +831,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( ...@@ -831,7 +831,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
const size_t num_bytes_per_seqlen = alignTo<16>(b * sizeof(int32_t)); const size_t num_bytes_per_seqlen = alignTo<16>(b * sizeof(int32_t));
const size_t actual_seqlen_workspace_size = is_padding ? 2 * num_bytes_per_seqlen : 0; const size_t actual_seqlen_workspace_size = is_padding ? 2 * num_bytes_per_seqlen : 0;
const size_t num_bytes_per_ragged_offset = const size_t num_bytes_per_ragged_offset =
alignTo<16>((b + 1) * typeToSize(ragged_offset_type)); alignTo<16>(((b + 1) * typeToNumBits(ragged_offset_type)) / 8);
size_t seqlen_offsets_workspace_size = 0; size_t seqlen_offsets_workspace_size = 0;
if (is_ragged_q || is_ragged_kv) { if (is_ragged_q || is_ragged_kv) {
size_t count = 2 * (static_cast<size_t>(is_ragged_q) + static_cast<size_t>(is_ragged_kv)); size_t count = 2 * (static_cast<size_t>(is_ragged_q) + static_cast<size_t>(is_ragged_kv));
...@@ -957,9 +957,9 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( ...@@ -957,9 +957,9 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
size_t stride = 0; size_t stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
stride = typeToSize(QKV_type) * num_attn_heads * head_dim; stride = (typeToNumBits(QKV_type) * num_attn_heads * head_dim) / 8;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) {
stride = typeToSize(QKV_type) * head_dim; stride = (typeToNumBits(QKV_type) * head_dim) / 8;
} }
void *devPtrQ = static_cast<void *>(devPtrQKV); void *devPtrQ = static_cast<void *>(devPtrQKV);
void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride); void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride);
...@@ -1082,9 +1082,9 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked( ...@@ -1082,9 +1082,9 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
size_t stride = 0; size_t stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
stride = typeToSize(QKV_type) * num_attn_heads * head_dim; stride = (typeToNumBits(QKV_type) * num_attn_heads * head_dim) / 8;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) {
stride = typeToSize(QKV_type) * head_dim; stride = (typeToNumBits(QKV_type) * head_dim) / 8;
} }
void *devPtrQ = devPtrQKV; void *devPtrQ = devPtrQKV;
void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride); void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride);
...@@ -1173,9 +1173,9 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( ...@@ -1173,9 +1173,9 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout);
size_t stride = 0; size_t stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
stride = typeToSize(QKV_type) * num_gqa_groups * head_dim; stride = (typeToNumBits(QKV_type) * num_gqa_groups * head_dim) / 8;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) {
stride = typeToSize(QKV_type) * head_dim; stride = (typeToNumBits(QKV_type) * head_dim) / 8;
} }
void *devPtrK = devPtrKV; void *devPtrK = devPtrKV;
void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrKV) + stride); void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrKV) + stride);
...@@ -1313,9 +1313,9 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( ...@@ -1313,9 +1313,9 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
size_t stride = 0; size_t stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
stride = typeToSize(QKV_type) * num_gqa_groups * head_dim; stride = (typeToNumBits(QKV_type) * num_gqa_groups * head_dim) / 8;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) {
stride = typeToSize(QKV_type) * head_dim; stride = (typeToNumBits(QKV_type) * head_dim) / 8;
} }
void *devPtrK = devPtrKV; void *devPtrK = devPtrKV;
void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrKV) + stride); void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrKV) + stride);
......
...@@ -2364,9 +2364,9 @@ void fused_attn_fp8_fwd_qkvpacked(size_t batch, size_t num_attn_heads, size_t ma ...@@ -2364,9 +2364,9 @@ void fused_attn_fp8_fwd_qkvpacked(size_t batch, size_t num_attn_heads, size_t ma
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
size_t stride = 0; size_t stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
stride = typeToSize(QKV_type) * num_attn_heads * head_dim; stride = (typeToNumBits(QKV_type) * num_attn_heads * head_dim) / 8;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) {
stride = typeToSize(QKV_type) * head_dim; stride = (typeToNumBits(QKV_type) * head_dim) / 8;
} }
void* devPtrQ = static_cast<void*>(devPtrQKV); void* devPtrQ = static_cast<void*>(devPtrQKV);
void* devPtrK = static_cast<void*>(static_cast<int8_t*>(devPtrQKV) + stride); void* devPtrK = static_cast<void*>(static_cast<int8_t*>(devPtrQKV) + stride);
...@@ -2466,9 +2466,9 @@ void fused_attn_fp8_bwd_qkvpacked( ...@@ -2466,9 +2466,9 @@ void fused_attn_fp8_bwd_qkvpacked(
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
size_t stride = 0; size_t stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
stride = typeToSize(QKV_type) * num_attn_heads * head_dim; stride = (typeToNumBits(QKV_type) * num_attn_heads * head_dim) / 8;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) {
stride = typeToSize(QKV_type) * head_dim; stride = (typeToNumBits(QKV_type) * head_dim) / 8;
} }
void* devPtrQ = devPtrQKV; void* devPtrQ = devPtrQKV;
void* devPtrK = static_cast<void*>(static_cast<int8_t*>(devPtrQKV) + stride); void* devPtrK = static_cast<void*>(static_cast<int8_t*>(devPtrQKV) + stride);
...@@ -2564,9 +2564,9 @@ void fused_attn_fp8_fwd_kvpacked(size_t batch, size_t num_attn_heads, size_t num ...@@ -2564,9 +2564,9 @@ void fused_attn_fp8_fwd_kvpacked(size_t batch, size_t num_attn_heads, size_t num
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
size_t stride = 0; size_t stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
stride = typeToSize(QKV_type) * num_gqa_groups * head_dim; stride = (typeToNumBits(QKV_type) * num_gqa_groups * head_dim) / 8;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) {
stride = typeToSize(QKV_type) * head_dim; stride = (typeToNumBits(QKV_type) * head_dim) / 8;
} }
void* devPtrK = devPtrKV; void* devPtrK = devPtrKV;
void* devPtrV = static_cast<void*>(static_cast<int8_t*>(devPtrKV) + stride); void* devPtrV = static_cast<void*>(static_cast<int8_t*>(devPtrKV) + stride);
...@@ -2671,9 +2671,9 @@ void fused_attn_fp8_bwd_kvpacked( ...@@ -2671,9 +2671,9 @@ void fused_attn_fp8_bwd_kvpacked(
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
size_t stride = 0; size_t stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
stride = typeToSize(QKV_type) * num_gqa_groups * head_dim; stride = (typeToNumBits(QKV_type) * num_gqa_groups * head_dim) / 8;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) {
stride = typeToSize(QKV_type) * head_dim; stride = (typeToNumBits(QKV_type) * head_dim) / 8;
} }
void* devPtrK = devPtrKV; void* devPtrK = devPtrKV;
void* devPtrV = static_cast<void*>(static_cast<int8_t*>(devPtrKV) + stride); void* devPtrV = static_cast<void*>(static_cast<int8_t*>(devPtrKV) + stride);
......
...@@ -22,17 +22,18 @@ extern "C" { ...@@ -22,17 +22,18 @@ extern "C" {
* \brief TE datatype. * \brief TE datatype.
*/ */
enum NVTEDType { enum NVTEDType {
kNVTEByte = 0, /*!< Byte */ kNVTEByte = 0, /*!< Byte */
kNVTEInt16 = 1, /*!< 16-bit integer */ kNVTEInt16 = 1, /*!< 16-bit integer */
kNVTEInt32 = 2, /*!< 32-bit integer */ kNVTEInt32 = 2, /*!< 32-bit integer */
kNVTEInt64 = 3, /*!< 64-bit integer */ kNVTEInt64 = 3, /*!< 64-bit integer */
kNVTEFloat32 = 4, /*!< 32-bit float */ kNVTEFloat32 = 4, /*!< 32-bit float */
kNVTEFloat16 = 5, /*!< 16-bit float (E5M10) */ kNVTEFloat16 = 5, /*!< 16-bit float (E5M10) */
kNVTEBFloat16 = 6, /*!< 16-bit bfloat (E8M7) */ kNVTEBFloat16 = 6, /*!< 16-bit bfloat (E8M7) */
kNVTEFloat8E4M3 = 7, /*!< 8-bit float (E4M3) */ kNVTEFloat8E4M3 = 7, /*!< 8-bit float (E4M3) */
kNVTEFloat8E5M2 = 8, /*!< 8-bit float (E5M2) */ kNVTEFloat8E5M2 = 8, /*!< 8-bit float (E5M2) */
kNVTEFloat8E8M0 = 9, /*!< 8-bit float (E8M0) */ kNVTEFloat8E8M0 = 9, /*!< 8-bit float (E8M0) */
kNVTENumTypes /*!< Number of supported types */ kNVTEFloat4E2M1 = 10, /*!< 4-bit float (E2M1) */
kNVTENumTypes /*!< Number of supported types */
}; };
/*! \struct NVTEShape /*! \struct NVTEShape
...@@ -87,6 +88,10 @@ enum NVTEScalingMode { ...@@ -87,6 +88,10 @@ enum NVTEScalingMode {
*/ */
NVTE_BLOCK_SCALING_1D = 2, NVTE_BLOCK_SCALING_1D = 2,
NVTE_BLOCK_SCALING_2D = 3, NVTE_BLOCK_SCALING_2D = 3,
/*! Single NVFP4 scale per block of 16 contiguous elements in forward pass (FWD),
and single MXFP8 scale per block of 32 contiguous elements in backward pass (BWD).
*/
NVTE_FWD_NVFP4_BWD_MXFP8_SCALING = 4,
NVTE_INVALID_SCALING = 100 NVTE_INVALID_SCALING = 100
}; };
...@@ -177,6 +182,14 @@ size_t nvte_tensor_ndims(const NVTETensor tensor); ...@@ -177,6 +182,14 @@ size_t nvte_tensor_ndims(const NVTETensor tensor);
*/ */
size_t nvte_tensor_size(const NVTETensor tensor, const size_t dim); size_t nvte_tensor_size(const NVTETensor tensor, const size_t dim);
/*! \brief Get the byte size for the tensor.
*
* \param[in] tensor Tensor.
*
* \return Byte size of the tensor.
*/
size_t nvte_tensor_size_bytes(const NVTETensor tensor);
/*! \brief Get a tensor's total number of elements. /*! \brief Get a tensor's total number of elements.
* *
* \param[in] tensor Tensor. * \param[in] tensor Tensor.
...@@ -193,6 +206,14 @@ size_t nvte_tensor_numel(const NVTETensor tensor); ...@@ -193,6 +206,14 @@ size_t nvte_tensor_numel(const NVTETensor tensor);
*/ */
size_t nvte_tensor_element_size(const NVTETensor tensor); size_t nvte_tensor_element_size(const NVTETensor tensor);
/*! \brief Get the bit size for the tensor's data type.
*
* \param[in] tensor Tensor.
*
* \return Bit size of the tensor's data type.
*/
size_t nvte_tensor_element_size_bits(const NVTETensor tensor);
/*! \brief Get a tensor's data type. /*! \brief Get a tensor's data type.
* *
* \param[in] tensor Tensor. * \param[in] tensor Tensor.
...@@ -390,6 +411,7 @@ enum class DType { ...@@ -390,6 +411,7 @@ enum class DType {
kFloat8E4M3 = 7, kFloat8E4M3 = 7,
kFloat8E5M2 = 8, kFloat8E5M2 = 8,
kFloat8E8M0 = 9, kFloat8E8M0 = 9,
kFloat4E2M1 = 10,
kNumTypes kNumTypes
}; };
...@@ -398,7 +420,16 @@ enum class DType { ...@@ -398,7 +420,16 @@ enum class DType {
* Return true if TE datatype is FP8 * Return true if TE datatype is FP8
* \param[in] DType TE Datatype of interest * \param[in] DType TE Datatype of interest
*/ */
bool is_fp8_dtype(const DType t); inline bool is_fp8_dtype(const DType t) {
return t == DType::kFloat8E4M3 || t == DType::kFloat8E5M2;
}
/*! \brief Check if TE datatype is FP4
*
* Return true if TE datatype is FP4
* \param[in] DType TE Datatype of interest
*/
inline bool is_fp4_dtype(const DType t) { return t == DType::kFloat4E2M1; }
/*! \struct TensorWrapper /*! \struct TensorWrapper
* \brief C++ wrapper for the NVTETensor class. * \brief C++ wrapper for the NVTETensor class.
...@@ -627,6 +658,15 @@ class TensorWrapper { ...@@ -627,6 +658,15 @@ class TensorWrapper {
return nvte_tensor_element_size(tensor_); return nvte_tensor_element_size(tensor_);
} }
/*! \brief Get the tensor's element size in bits.
*
* \return Element size in bits.
*/
size_t element_size_bits() const noexcept {
if (tensor_ == nullptr) return 0;
return nvte_tensor_element_size_bits(tensor_);
}
/*! \brief Get the tensor's allocated size in bytes. This will return 0 for tensors with nullptr /*! \brief Get the tensor's allocated size in bytes. This will return 0 for tensors with nullptr
* data even if the TensorWrapper has a non-zero shape and valid dtype. * data even if the TensorWrapper has a non-zero shape and valid dtype.
* *
...@@ -634,7 +674,7 @@ class TensorWrapper { ...@@ -634,7 +674,7 @@ class TensorWrapper {
*/ */
size_t bytes() const noexcept { size_t bytes() const noexcept {
if (tensor_ == nullptr || this->dptr() == nullptr) return 0; if (tensor_ == nullptr || this->dptr() == nullptr) return 0;
return nvte_tensor_numel(tensor_) * nvte_tensor_element_size(tensor_); return nvte_tensor_size_bytes(tensor_);
} }
/*! \brief Get the data type of this TensorWrapper. /*! \brief Get the data type of this TensorWrapper.
......
...@@ -212,8 +212,11 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor ...@@ -212,8 +212,11 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor
} }
const auto gamma_dtype = use_zero_centered_gamma_in_weight_dtype() ? wtype : ctype; const auto gamma_dtype = use_zero_centered_gamma_in_weight_dtype() ? wtype : ctype;
NVTE_CHECK(gamma_dtype == DType::kFloat32 || gamma_dtype == DType::kFloat16 ||
gamma_dtype == DType::kBFloat16,
"Gamma of type FP4 is not supported");
_scalar_dptr = std::make_unique<char[]>(typeToSize(gamma_dtype)); _scalar_dptr = std::make_unique<char[]>(typeToNumBits(gamma_dtype) / 8);
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
gamma_dtype, cpp_dtype, gamma_dtype, cpp_dtype,
*(reinterpret_cast<cpp_dtype*>(_scalar_dptr.get())) = (cpp_dtype)1.0f;); *(reinterpret_cast<cpp_dtype*>(_scalar_dptr.get())) = (cpp_dtype)1.0f;);
......
...@@ -18,12 +18,15 @@ ...@@ -18,12 +18,15 @@
namespace transformer_engine { namespace transformer_engine {
size_t typeToSize(const DType type) { size_t typeToNumBits(const DType type) {
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(type, T, TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(type, T,
return TypeInfo<T>::size;); // NOLINT(*) return TypeInfo<T>::size;); // NOLINT(*)
} }
bool is_fp8_dtype(const DType t) { return t == DType::kFloat8E4M3 || t == DType::kFloat8E5M2; } size_t typeToSize(const DType type) {
NVTE_CHECK(type != DType::kFloat4E2M1, "typeToSize() Does not support FP4 data type.");
return typeToNumBits(type) / 8;
}
std::string to_string(const DType type) { std::string to_string(const DType type) {
switch (type) { switch (type) {
...@@ -41,6 +44,8 @@ std::string to_string(const DType type) { ...@@ -41,6 +44,8 @@ std::string to_string(const DType type) {
return "Float8E5M2"; return "Float8E5M2";
case DType::kFloat8E8M0: case DType::kFloat8E8M0:
return "Float8E8M0"; return "Float8E8M0";
case DType::kFloat4E2M1:
return "Float4E2M1";
case DType::kInt32: case DType::kInt32:
return "Int32"; return "Int32";
case DType::kInt64: case DType::kInt64:
...@@ -56,6 +61,8 @@ std::string to_string(const NVTEScalingMode &mode) { ...@@ -56,6 +61,8 @@ std::string to_string(const NVTEScalingMode &mode) {
return "NVTE_DELAYED_TENSOR_SCALING"; return "NVTE_DELAYED_TENSOR_SCALING";
case NVTE_MXFP8_1D_SCALING: case NVTE_MXFP8_1D_SCALING:
return "NVTE_MXFP8_1D_SCALING"; return "NVTE_MXFP8_1D_SCALING";
case NVTE_FWD_NVFP4_BWD_MXFP8_SCALING:
return "NVTE_FWD_NVFP4_BWD_MXFP8_SCALING";
case NVTE_INVALID_SCALING: case NVTE_INVALID_SCALING:
return "NVTE_INVALID_SCALING"; return "NVTE_INVALID_SCALING";
} }
...@@ -85,10 +92,13 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) { ...@@ -85,10 +92,13 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) {
t.columnwise_scale_inv.shape, ")"); t.columnwise_scale_inv.shape, ")");
} }
} else { } else {
if (t.scaling_mode == NVTE_MXFP8_1D_SCALING) { if (t.scaling_mode == NVTE_MXFP8_1D_SCALING ||
t.scaling_mode == NVTE_FWD_NVFP4_BWD_MXFP8_SCALING) {
// Need (4, 128) alignment even for e8 scaling factor // Need (4, 128) alignment even for e8 scaling factor
auto block_alignment = std::vector<size_t>{128ul, 4ul}; auto block_alignment = std::vector<size_t>{128ul, 4ul};
size_t expected_x, expected_y, alignment; size_t expected_x, expected_y, alignment;
const size_t block_size_rowwise = (t.scaling_mode == NVTE_MXFP8_1D_SCALING) ? 32 : 16;
const size_t block_size_colwise = 32;
if (t.has_data()) { if (t.has_data()) {
alignment = block_alignment[0]; alignment = block_alignment[0];
...@@ -96,7 +106,8 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) { ...@@ -96,7 +106,8 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) {
DIVUP(DIVUP(t.flat_first_dim(), static_cast<size_t>(1)), alignment) * alignment; DIVUP(DIVUP(t.flat_first_dim(), static_cast<size_t>(1)), alignment) * alignment;
alignment = block_alignment[1]; alignment = block_alignment[1];
expected_y = expected_y =
DIVUP(DIVUP(t.flat_last_dim(), static_cast<size_t>(32)), alignment) * alignment; DIVUP(DIVUP(t.flat_last_dim(), static_cast<size_t>(block_size_rowwise)), alignment) *
alignment;
const auto &expected = std::vector<size_t>{expected_x, expected_y}; const auto &expected = std::vector<size_t>{expected_x, expected_y};
NVTE_CHECK(t.scale_inv.shape == expected, "Tensor \"", name, NVTE_CHECK(t.scale_inv.shape == expected, "Tensor \"", name,
"\" has invalid scale_inv shape (expected ", expected, ", got ", "\" has invalid scale_inv shape (expected ", expected, ", got ",
...@@ -105,7 +116,8 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) { ...@@ -105,7 +116,8 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) {
if (t.has_columnwise_data()) { if (t.has_columnwise_data()) {
alignment = block_alignment[1]; alignment = block_alignment[1];
expected_x = expected_x =
DIVUP(DIVUP(t.flat_first_dim(), static_cast<size_t>(32)), alignment) * alignment; DIVUP(DIVUP(t.flat_first_dim(), static_cast<size_t>(block_size_colwise)), alignment) *
alignment;
alignment = block_alignment[0]; alignment = block_alignment[0];
expected_y = DIVUP(DIVUP(t.flat_last_dim(), static_cast<size_t>(1)), alignment) * alignment; expected_y = DIVUP(DIVUP(t.flat_last_dim(), static_cast<size_t>(1)), alignment) * alignment;
const auto &expected = std::vector<size_t>{expected_x, expected_y}; const auto &expected = std::vector<size_t>{expected_x, expected_y};
...@@ -384,10 +396,24 @@ size_t nvte_tensor_numel(const NVTETensor tensor) { ...@@ -384,10 +396,24 @@ size_t nvte_tensor_numel(const NVTETensor tensor) {
return numel; return numel;
} }
size_t nvte_tensor_element_size_bits(const NVTETensor tensor) {
  // Bit width of a single element of the tensor's dtype. Expressed in bits
  // (not bytes) so that sub-byte types such as FP4 are representable.
  // A null handle falls back to the width of float (32 bits).
  const auto *tensor_ptr = transformer_engine::convertNVTETensor(tensor);
  if (tensor_ptr == nullptr) {
    return sizeof(float) * 8;
  }
  return transformer_engine::typeToNumBits(tensor_ptr->dtype());
}
size_t nvte_tensor_element_size(const NVTETensor tensor) { size_t nvte_tensor_element_size(const NVTETensor tensor) {
auto *t = transformer_engine::convertNVTETensor(tensor); auto *t = transformer_engine::convertNVTETensor(tensor);
if (t == nullptr) return sizeof(float); if (t == nullptr) return sizeof(float);
return transformer_engine::typeToSize(t->dtype()); NVTE_CHECK(!is_fp4_dtype(t->dtype()),
"For FP4 type please use the nvte_tensor_element_size_bits.");
return nvte_tensor_element_size_bits(tensor) / 8;
}
size_t nvte_tensor_size_bytes(const NVTETensor tensor) {
  // Total payload size of the tensor data in bytes.
  // Computed as numel * element_bits / 8; multiplying before dividing keeps
  // the result exact for sub-byte element types.
  const auto *tensor_ptr = transformer_engine::convertNVTETensor(tensor);
  if (tensor_ptr == nullptr) {
    return 0;
  }
  const size_t total_bits = nvte_tensor_numel(tensor) * nvte_tensor_element_size_bits(tensor);
  return total_bits / 8;
}
void *nvte_tensor_data(const NVTETensor tensor) { void *nvte_tensor_data(const NVTETensor tensor) {
...@@ -514,7 +540,7 @@ void nvte_zero_tensor(const NVTETensor tensor, cudaStream_t stream) { ...@@ -514,7 +540,7 @@ void nvte_zero_tensor(const NVTETensor tensor, cudaStream_t stream) {
const auto &t = *transformer_engine::convertNVTETensorCheck(tensor); const auto &t = *transformer_engine::convertNVTETensorCheck(tensor);
// Zero out tensor data if allocated // Zero out tensor data if allocated
if (t.data.dptr != nullptr) { if (t.data.dptr != nullptr) {
size_t size_in_bytes = nvte_tensor_element_size(tensor) * nvte_tensor_numel(tensor); const size_t size_in_bytes = nvte_tensor_size_bytes(tensor);
cudaMemsetAsync(t.data.dptr, 0, size_in_bytes, stream); cudaMemsetAsync(t.data.dptr, 0, size_in_bytes, stream);
} }
// Set amax to 0 if allocated // Set amax to 0 if allocated
......
...@@ -192,17 +192,18 @@ void populate_cast_transpose_dbias_workspace_config(const Tensor &cast_output, / ...@@ -192,17 +192,18 @@ void populate_cast_transpose_dbias_workspace_config(const Tensor &cast_output, /
workspace->data.dtype = DType::kFloat32; workspace->data.dtype = DType::kFloat32;
} else { } else {
// Check that workspace matches expected size // Check that workspace matches expected size
const size_t workspace_size = const size_t workspace_size = get_buffer_size_bytes(
std::accumulate(workspace->data.shape.begin(), workspace->data.shape.end(), 1, std::accumulate(workspace->data.shape.begin(), workspace->data.shape.end(), 1,
std::multiplies<size_t>()) * std::multiplies<size_t>()),
typeToSize(workspace->data.dtype); workspace->data.dtype);
const size_t required_size = num_rows_partial_dbias * row_length * typeToSize(DType::kFloat32); const size_t required_size =
get_buffer_size_bytes(num_rows_partial_dbias, row_length, DType::kFloat32);
NVTE_CHECK(!workspace->data.shape.empty(), "Invalid workspace dims (expected (", NVTE_CHECK(!workspace->data.shape.empty(), "Invalid workspace dims (expected (",
num_rows_partial_dbias, ",", row_length, "), found ())"); num_rows_partial_dbias, ",", row_length, "), found ())");
NVTE_CHECK(workspace_size >= required_size, "Invalid workspace (expected dims=(", NVTE_CHECK(workspace_size >= required_size, "Invalid workspace (expected dims=(",
num_rows_partial_dbias, ",", row_length, "), dtype=", to_string(DType::kFloat32), num_rows_partial_dbias, ",", row_length, "), dtype=", to_string(DType::kFloat32),
"; found dims=", workspace->data.shape, "; found dims=", workspace->data.shape,
", dtype=", typeToSize(workspace->data.dtype), ")"); ", dtype=", typeToNumBits(workspace->data.dtype), " bits)");
} }
} }
......
...@@ -237,8 +237,8 @@ void multi_cast_transpose(const std::vector<Tensor*> input_list, std::vector<Ten ...@@ -237,8 +237,8 @@ void multi_cast_transpose(const std::vector<Tensor*> input_list, std::vector<Ten
// Input matrices are divided into tiles // Input matrices are divided into tiles
// Note: Each tile is a warp_size x warp_size grid of nvec_out x nvec_in subtiles // Note: Each tile is a warp_size x warp_size grid of nvec_out x nvec_in subtiles
const int tile_dim_m = THREADS_PER_WARP * desired_store_size / typeToSize(otype); const int tile_dim_m = THREADS_PER_WARP * desired_store_size * 8 / typeToNumBits(otype);
const int tile_dim_n = THREADS_PER_WARP * desired_load_size / typeToSize(itype); const int tile_dim_n = THREADS_PER_WARP * desired_load_size * 8 / typeToNumBits(itype);
// Add tensors to kernel argument struct // Add tensors to kernel argument struct
MultiCastTransposeArgs kernel_args_aligned, kernel_args_unaligned; MultiCastTransposeArgs kernel_args_aligned, kernel_args_unaligned;
......
...@@ -460,7 +460,7 @@ CUtensorMap get_tensor_map(const SimpleTensor& tensor, size_t global_dim_x, size ...@@ -460,7 +460,7 @@ CUtensorMap get_tensor_map(const SimpleTensor& tensor, size_t global_dim_x, size
CUtensorMap tensor_map_output_trans{}; CUtensorMap tensor_map_output_trans{};
create_2D_tensor_map(tensor_map_output_trans, tensor, global_dim_y, global_dim_x, create_2D_tensor_map(tensor_map_output_trans, tensor, global_dim_y, global_dim_x,
/*shmemY=*/BLOCK_TILE_DIM, /*shmemX=*/BLOCK_TILE_DIM, /*shmemY=*/BLOCK_TILE_DIM, /*shmemX=*/BLOCK_TILE_DIM,
/*stride_elems=*/global_dim_x, /*offset_elems=*/0, sizeof(OutputType)); /*stride_elems=*/global_dim_x, /*offset_elems=*/0, sizeof(OutputType) * 8);
return tensor_map_output_trans; return tensor_map_output_trans;
} }
......
...@@ -382,17 +382,18 @@ void populate_transpose_dbias_workspace_config(const Tensor &input, /*cast*/ ...@@ -382,17 +382,18 @@ void populate_transpose_dbias_workspace_config(const Tensor &input, /*cast*/
workspace->data.dtype = DType::kFloat32; workspace->data.dtype = DType::kFloat32;
} else { } else {
// Check that workspace matches expected size // Check that workspace matches expected size
const size_t workspace_size = const size_t workspace_size = get_buffer_size_bytes(
std::accumulate(workspace->data.shape.begin(), workspace->data.shape.end(), 1, std::accumulate(workspace->data.shape.begin(), workspace->data.shape.end(), 1,
std::multiplies<size_t>()) * std::multiplies<size_t>()),
typeToSize(workspace->data.dtype); workspace->data.dtype);
const size_t required_size = num_rows_partial_dbias * row_length * typeToSize(DType::kFloat32); const size_t required_size =
get_buffer_size_bytes(num_rows_partial_dbias, row_length, DType::kFloat32);
NVTE_CHECK(!workspace->data.shape.empty(), "Invalid workspace dims (expected (", NVTE_CHECK(!workspace->data.shape.empty(), "Invalid workspace dims (expected (",
num_rows_partial_dbias, ",", row_length, "), found ())"); num_rows_partial_dbias, ",", row_length, "), found ())");
NVTE_CHECK(workspace_size >= required_size, "Invalid workspace (expected dims=(", NVTE_CHECK(workspace_size >= required_size, "Invalid workspace (expected dims=(",
num_rows_partial_dbias, ",", row_length, "), dtype=", to_string(DType::kFloat32), num_rows_partial_dbias, ",", row_length, "), dtype=", to_string(DType::kFloat32),
"; found dims=", workspace->data.shape, "; found dims=", workspace->data.shape,
", dtype=", typeToSize(workspace->data.dtype), ")"); ", dtype=", typeToNumBits(workspace->data.dtype), " bits)");
} }
} }
......
...@@ -754,19 +754,20 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu ...@@ -754,19 +754,20 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu
if constexpr (IS_DGATED) { if constexpr (IS_DGATED) {
create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X, create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X,
cols, 0, sizeof(IType)); cols, 0, typeToNumBits(gated_input.dtype()));
} }
const uint32_t tensor_stride_elems = output_cols; const uint32_t tensor_stride_elems = output_cols;
create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, SHMEM_DIM_Y, create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, SHMEM_DIM_Y,
SHMEM_DIM_X, cols * 2, 0, sizeof(IType)); SHMEM_DIM_X, cols * 2, 0, typeToNumBits(gated_input.dtype()));
create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, SHMEM_DIM_Y, create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, SHMEM_DIM_Y,
SHMEM_DIM_X, cols * 2, cols, sizeof(IType)); SHMEM_DIM_X, cols * 2, cols, typeToNumBits(gated_input.dtype()));
create_2D_tensor_map(tensor_map_output_act, output->data, rows, cols, SHMEM_DIM_Y, create_2D_tensor_map(tensor_map_output_act, output->data, rows, cols, SHMEM_DIM_Y,
SHMEM_DIM_X, tensor_stride_elems, 0, sizeof(OType)); SHMEM_DIM_X, tensor_stride_elems, 0, typeToNumBits(output->dtype()));
create_2D_tensor_map(tensor_map_output_gate, output->data, rows, cols, SHMEM_DIM_Y, create_2D_tensor_map(tensor_map_output_gate, output->data, rows, cols, SHMEM_DIM_Y,
SHMEM_DIM_X, tensor_stride_elems, cols, sizeof(OType)); SHMEM_DIM_X, tensor_stride_elems, cols,
typeToNumBits(output->dtype()));
const size_t buff_elems_total = BUFFERS_NUM * SHMEM_DIM_Y * SHMEM_DIM_X; const size_t buff_elems_total = BUFFERS_NUM * SHMEM_DIM_Y * SHMEM_DIM_X;
const size_t buff_size_aligned_in = const size_t buff_size_aligned_in =
...@@ -849,31 +850,33 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out ...@@ -849,31 +850,33 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
if constexpr (IS_DGATED) { if constexpr (IS_DGATED) {
create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, SHMEM_DIM_Y, create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, SHMEM_DIM_Y,
SHMEM_DIM_X, cols, 0, sizeof(IType)); SHMEM_DIM_X, cols, 0, typeToNumBits(gated_input.dtype()));
} }
const uint32_t tensor_stride_elems = output_cols; const uint32_t tensor_stride_elems = output_cols;
create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols,
SHMEM_DIM_Y, SHMEM_DIM_X, cols * 2, 0, sizeof(IType)); SHMEM_DIM_Y, SHMEM_DIM_X, cols * 2, 0,
typeToNumBits(gated_input.dtype()));
create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols,
SHMEM_DIM_Y, SHMEM_DIM_X, cols * 2, cols, sizeof(IType)); SHMEM_DIM_Y, SHMEM_DIM_X, cols * 2, cols,
typeToNumBits(gated_input.dtype()));
if (USE_ROWWISE_SCALING) { if (USE_ROWWISE_SCALING) {
create_2D_tensor_map(tensor_map_output_act_rowwise, output->data, rows, cols, create_2D_tensor_map(tensor_map_output_act_rowwise, output->data, rows, cols,
SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems, 0, SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems, 0,
sizeof(OType)); typeToNumBits(output->dtype()));
create_2D_tensor_map(tensor_map_output_gate_rowwise, output->data, rows, cols, create_2D_tensor_map(tensor_map_output_gate_rowwise, output->data, rows, cols,
SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems, cols, SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems, cols,
sizeof(OType)); typeToNumBits(output->dtype()));
} }
if (USE_COLWISE_SCALING) { if (USE_COLWISE_SCALING) {
create_2D_tensor_map(tensor_map_output_act_colwise, output->columnwise_data, create_2D_tensor_map(tensor_map_output_act_colwise, output->columnwise_data,
rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems, rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems,
0, sizeof(OType)); 0, typeToNumBits(output->dtype()));
create_2D_tensor_map(tensor_map_output_gate_colwise, output->columnwise_data, create_2D_tensor_map(tensor_map_output_gate_colwise, output->columnwise_data,
rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems, rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems,
cols, sizeof(OType)); cols, typeToNumBits(output->dtype()));
} }
const size_t buff_elems_total = BUFFERS_NUM * SHMEM_DIM_Y * SHMEM_DIM_X; const size_t buff_elems_total = BUFFERS_NUM * SHMEM_DIM_Y * SHMEM_DIM_X;
......
...@@ -895,15 +895,15 @@ void cast_fp8_2D(const Tensor &input, const Tensor *act_input, Tensor *output, T ...@@ -895,15 +895,15 @@ void cast_fp8_2D(const Tensor &input, const Tensor *act_input, Tensor *output, T
alignas(64) CUtensorMap tensor_map_output{}; alignas(64) CUtensorMap tensor_map_output{};
create_2D_tensor_map(tensor_map_input, input.data, rows, cols, FP8_SHMEM_DIM_Y, create_2D_tensor_map(tensor_map_input, input.data, rows, cols, FP8_SHMEM_DIM_Y,
FP8_SHMEM_DIM_X, cols, 0, sizeof(IType)); FP8_SHMEM_DIM_X, cols, 0, typeToNumBits(input.data.dtype));
if constexpr (IS_DACT) { if constexpr (IS_DACT) {
create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols, FP8_SHMEM_DIM_Y, create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols, FP8_SHMEM_DIM_Y,
FP8_SHMEM_DIM_X, cols, 0, sizeof(IType)); FP8_SHMEM_DIM_X, cols, 0, typeToNumBits(input.data.dtype));
} }
create_2D_tensor_map(tensor_map_output, output->data, rows, cols, FP8_SHMEM_DIM_Y, create_2D_tensor_map(tensor_map_output, output->data, rows, cols, FP8_SHMEM_DIM_Y,
FP8_SHMEM_DIM_X, cols, 0, sizeof(OType)); FP8_SHMEM_DIM_X, cols, 0, typeToNumBits(output->data.dtype));
cast_fp8_2D_kernel<IS_DBIAS, IS_DACT, ParamOP, OP, IType, OType> cast_fp8_2D_kernel<IS_DBIAS, IS_DACT, ParamOP, OP, IType, OType>
<<<grid, block, 0, stream>>>(tensor_map_input, tensor_map_act_input, tensor_map_output, <<<grid, block, 0, stream>>>(tensor_map_input, tensor_map_act_input, tensor_map_output,
...@@ -991,24 +991,24 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, ...@@ -991,24 +991,24 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
alignas(64) CUtensorMap tensor_map_output_colwise{}; alignas(64) CUtensorMap tensor_map_output_colwise{};
create_2D_tensor_map(tensor_map_input, input.data, rows, cols, MXFP8_SHMEM_DIM_Y, create_2D_tensor_map(tensor_map_input, input.data, rows, cols, MXFP8_SHMEM_DIM_Y,
MXFP8_SHMEM_DIM_X, cols, 0, sizeof(IType)); MXFP8_SHMEM_DIM_X, cols, 0, typeToNumBits(input.dtype()));
if constexpr (IS_DACT) { if constexpr (IS_DACT) {
create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols, create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols,
MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, cols, 0, MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, cols, 0,
sizeof(IType)); typeToNumBits(input.dtype()));
} }
if (use_rowwise_scaling) { if (use_rowwise_scaling) {
create_2D_tensor_map(tensor_map_output_rowwise, output->data, rows, cols, create_2D_tensor_map(tensor_map_output_rowwise, output->data, rows, cols,
MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, cols, 0, MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, cols, 0,
sizeof(OType)); typeToNumBits(output->dtype()));
} }
if (use_colwise_scaling) { if (use_colwise_scaling) {
create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, rows, create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, rows,
cols, MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, cols, 0, cols, MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, cols, 0,
sizeof(OType)); typeToNumBits(output->dtype()));
} }
cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType,
...@@ -1101,7 +1101,7 @@ static bool is_full_tile_1D_tensor(const Tensor *const t) { ...@@ -1101,7 +1101,7 @@ static bool is_full_tile_1D_tensor(const Tensor *const t) {
bool dimensions_supported_by_TMA(const Tensor *const t) { bool dimensions_supported_by_TMA(const Tensor *const t) {
const size_t cols = t->flat_last_dim(); const size_t cols = t->flat_last_dim();
constexpr int TMA_bytes = 16; constexpr int TMA_bytes = 16;
const int alignment_requirement = TMA_bytes / typeToSize(t->dtype()); const int alignment_requirement = (TMA_bytes * 8) / typeToNumBits(t->dtype());
return cols % alignment_requirement == 0; return cols % alignment_requirement == 0;
} }
......
...@@ -319,9 +319,9 @@ static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t s ...@@ -319,9 +319,9 @@ static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t s
alignas(64) CUtensorMap tensor_map_output{}; alignas(64) CUtensorMap tensor_map_output{};
create_2D_tensor_map(tensor_map_input, input_data, rows, cols, SHMEM_DIM_Y, create_2D_tensor_map(tensor_map_input, input_data, rows, cols, SHMEM_DIM_Y,
SHMEM_DIM_X, cols, 0, sizeof(IType)); SHMEM_DIM_X, cols, 0, typeToNumBits(input.dtype()));
create_2D_tensor_map(tensor_map_output, output->data, rows, cols, SHMEM_DIM_Y, create_2D_tensor_map(tensor_map_output, output->data, rows, cols, SHMEM_DIM_Y,
SHMEM_DIM_X, cols, 0, sizeof(OType)); SHMEM_DIM_X, cols, 0, typeToNumBits(output->dtype()));
dequantize_mxfp8_kernel<IType, OType, SCALE_DIM_Y, SCALE_DIM_X> dequantize_mxfp8_kernel<IType, OType, SCALE_DIM_Y, SCALE_DIM_X>
<<<grid, block, 0, stream>>>(tensor_map_input, tensor_map_output, scales_ptr, <<<grid, block, 0, stream>>>(tensor_map_input, tensor_map_output, scales_ptr,
......
...@@ -155,8 +155,8 @@ void multi_padding(const std::vector<Tensor*> input_list, std::vector<Tensor*> o ...@@ -155,8 +155,8 @@ void multi_padding(const std::vector<Tensor*> input_list, std::vector<Tensor*> o
// Input matrices are divided into tiles // Input matrices are divided into tiles
// Note: Each tile is a warp_size x warp_size grid of nvec x nvec subtiles // Note: Each tile is a warp_size x warp_size grid of nvec x nvec subtiles
const int tile_dim_m = THREADS_PER_WARP * desired_load_store_size / typeToSize(type); const int tile_dim_m = THREADS_PER_WARP * desired_load_store_size * 8 / typeToNumBits(type);
const int tile_dim_n = THREADS_PER_WARP * desired_load_store_size / typeToSize(type); const int tile_dim_n = THREADS_PER_WARP * desired_load_store_size * 8 / typeToNumBits(type);
// Add tensors to kernel argument struct // Add tensors to kernel argument struct
MultiPaddingArgs kernel_args; MultiPaddingArgs kernel_args;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment