Unverified Commit 7b94bd99 authored by Oleg Goncharov, committed by GitHub

[common] Added support of FP4 data type (#1779)



* Added support of FP4 data type
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Refactoring to BitsNum in progress
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Fixed compilation errors. All C++ tests passed
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Fixed a typo
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Added FP4 guard to TMA tensor descriptor data type
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fixed errors in JAX C++ extensions
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Removed dummy NVFP4 C++ test file
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Make pytorch changes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Refactored the code per the review notes. Fixed JAX build error.
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Removed unnecessary static casts
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Typo fix
Signed-off-by: Oleg Goncharov <64355998+Oleg-Goncharov@users.noreply.github.com>

* Pass correct num bits to create_2D_tensor_map; fixes CI
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* inline funcs
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Oleg Goncharov <64355998+Oleg-Goncharov@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent e963e4a9
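
The central change in this commit is that per-element sizes are now tracked in bits rather than bytes (typeToSize gives way to typeToNumBits), since a single FP4 value occupies only half a byte. A minimal standalone sketch of the idea, using a simplified enum rather than the library's actual definitions:

```cpp
#include <cstddef>

// Simplified sketch (not the library's definitions) of the bit-based sizing
// introduced by this commit: element widths are tracked in bits so that a
// 4-bit FP4 element no longer has to round to a whole number of bytes.
enum class DType { kFloat32, kBFloat16, kFloat8E4M3, kFloat4E2M1 };

constexpr std::size_t typeToNumBits(DType t) {
  switch (t) {
    case DType::kFloat32:    return 32;
    case DType::kBFloat16:   return 16;
    case DType::kFloat8E4M3: return 8;
    case DType::kFloat4E2M1: return 4;  // sizeof() cannot express a half byte
  }
  return 0;
}

// Multiply by bits first and divide by 8 last, so an even number of FP4
// elements packs exactly two per byte.
constexpr std::size_t buffer_bytes(std::size_t elems, DType t) {
  return (elems * typeToNumBits(t)) / 8;
}

static_assert(buffer_bytes(1024, DType::kFloat4E2M1) == 512,
              "two FP4 values per byte");
static_assert(buffer_bytes(1024, DType::kFloat8E4M3) == 1024,
              "one FP8 value per byte");
```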
......@@ -67,7 +67,8 @@ inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const
// Remove the use_cudnn check here when it is supported by both backends.
const bool zero_centered_gamma_in_weight_dtype = use_cudnn && cudnn_zero_centered_gamma_in_weight_dtype;
if constexpr (std::is_same_v<InputType, fp8e5m2> || std::is_same_v<InputType, fp8e4m3>){
if constexpr (std::is_same_v<InputType, fp8e5m2> || std::is_same_v<InputType, fp8e4m3> ||
std::is_same_v<InputType, fp4e2m1>){
compute_t g = static_cast<compute_t>(gamma);
if (zero_centered_gamma) {
g += static_cast<compute_t>(1.f);
......
......@@ -45,7 +45,7 @@ bool areShapesEqual(const NVTEShape &s1, const NVTEShape &s2) {
return true;
}
size_t typeToSize(DType type) {
size_t typeToNumBits(DType type) {
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(type, T,
{
return TypeInfo<T>::size;
......@@ -62,7 +62,8 @@ const std::string &typeName(DType type) {
{DType::kBFloat16, "bfloat16"},
{DType::kFloat8E4M3, "float8e4m3"},
{DType::kFloat8E5M2, "float8e5m2"},
{DType::kFloat8E8M0, "float8e8m0"}};
{DType::kFloat8E8M0, "float8e8m0"},
{DType::kFloat4E2M1, "float4e2m1"}};
return name_map.at(type);
}
......@@ -109,9 +110,16 @@ size_t DIVUP(const size_t &x, const size_t &y){
struct scale_inv_meta {
std::vector<size_t> shape;
DType type;
size_t type_size;
size_t type_size_bits;
size_t bytes() const noexcept {
return (product(shape) * type_size_bits) / 8;
}
};
size_t bytes(const NVTEShape& shape, const DType type) {
return (product(shape) * typeToNumBits(type)) / 8;
}
NVTEShape convertShape(const std::vector<size_t>& s) {
return nvte_make_shape(s.data(), s.size());
}
......@@ -122,7 +130,7 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
scale_inv_meta ret;
ret.shape = {1};
ret.type = DType::kFloat32;
ret.type_size = sizeof(float);
ret.type_size_bits = typeToNumBits(DType::kFloat32);
return {ret, ret};
}
if (scaling_mode == NVTE_MXFP8_1D_SCALING) {
......@@ -152,8 +160,8 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
}
ret_rowwise.type = DType::kFloat8E8M0;
ret_colwise.type = DType::kFloat8E8M0;
ret_rowwise.type_size = sizeof(uint8_t);
ret_colwise.type_size = sizeof(uint8_t);
ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0);
ret_colwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0);
return {ret_rowwise, ret_colwise};
}
......@@ -179,8 +187,8 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
}
ret_rowwise.type = DType::kFloat32;
ret_colwise.type = DType::kFloat32;
ret_rowwise.type_size = sizeof(float);
ret_colwise.type_size = sizeof(float);
ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat32);
ret_colwise.type_size_bits = typeToNumBits(DType::kFloat32);
return {ret_rowwise, ret_colwise};
}
......@@ -205,8 +213,8 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
}
ret_rowwise.type = DType::kFloat32;
ret_colwise.type = DType::kFloat32;
ret_rowwise.type_size = sizeof(float);
ret_colwise.type_size = sizeof(float);
ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat32);
ret_colwise.type_size_bits = typeToNumBits(DType::kFloat32);
return {ret_rowwise, ret_colwise};
}
......@@ -222,8 +230,7 @@ Tensor::Tensor(const std::string& name,
gen_.seed(seed);
rowwise_ = rowwise;
columnwise_ = columnwise;
size_t s = typeToSize(type);
size_t total_size = product(shape) * s;
size_t total_size = bytes(shape, type);
void *dptr_rowwise = nullptr;
void *dptr_columnwise = nullptr;
cpu_data_rowwise_ = nullptr;
......@@ -305,8 +312,8 @@ Tensor::Tensor(const std::string& name,
} else {
auto [rowwise_scale_meta, colwise_scale_meta] =
get_scales(normalized_shape, tensor_.scaling_mode());
auto rowwise_scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size;
auto columnwise_scale_size = product(colwise_scale_meta.shape) * colwise_scale_meta.type_size;
auto rowwise_scale_size = rowwise_scale_meta.bytes();
auto columnwise_scale_size = colwise_scale_meta.bytes();
auto scale_shape = rowwise_scale_meta.shape;
auto columnwise_scale_shape = colwise_scale_meta.shape;
if (rowwise) {
......@@ -331,7 +338,7 @@ Tensor::Tensor(const std::string& name,
void Tensor::to_cpu() const {
const NVTEShape s = tensor_.shape();
const size_t size = product(s) * typeToSize(tensor_.dtype());
const size_t size = bytes(s, tensor_.dtype());
if (rowwise_) {
cudaMemcpy(cpu_data_rowwise_.get(),
tensor_.get_rowwise_data().data_ptr,
......@@ -360,14 +367,14 @@ void Tensor::to_cpu() const {
auto [rowwise_scale_meta, colwise_scale_meta] =
get_scales(s, tensor_.scaling_mode());
if (rowwise_) {
auto scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size;
auto scale_size = rowwise_scale_meta.bytes();
cudaMemcpy(rowwise_scale_inv_cpu_data_.get(),
tensor_.get_rowwise_scale_inv().data_ptr,
scale_size,
cudaMemcpyDeviceToHost);
}
if (columnwise_) {
auto scale_size = product(colwise_scale_meta.shape) * colwise_scale_meta.type_size;
auto scale_size = colwise_scale_meta.bytes();
cudaMemcpy(columnwise_scale_inv_cpu_data_.get(),
tensor_.get_columnwise_scale_inv().data_ptr,
scale_size,
......@@ -378,34 +385,32 @@ void Tensor::to_cpu() const {
void Tensor::from_cpu() const {
const NVTEShape s = tensor_.shape();
const size_t size = product(s) * typeToSize(tensor_.dtype());
const size_t size = bytes(s, tensor_.dtype());
if (rowwise_) {
cudaMemcpy(tensor_.get_rowwise_data().data_ptr,
cpu_data_rowwise_.get(), size, cudaMemcpyHostToDevice);
cudaMemcpy(tensor_.get_rowwise_data().data_ptr, cpu_data_rowwise_.get(), size,
cudaMemcpyHostToDevice);
}
if (columnwise_) {
cudaMemcpy(tensor_.get_columnwise_data().data_ptr,
cpu_data_columnwise_.get(), size, cudaMemcpyHostToDevice);
cudaMemcpy(tensor_.get_columnwise_data().data_ptr, cpu_data_columnwise_.get(), size,
cudaMemcpyHostToDevice);
}
if (isFp8Type(dtype())) {
if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
if (tensor_.amax() != nullptr){
cudaMemcpy(tensor_.amax(), amax_cpu_data_.get(), sizeof(float),
cudaMemcpyHostToDevice);
cudaMemcpy(tensor_.amax(), amax_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice);
}
cudaMemcpy(tensor_.scale(), scale_cpu_data_.get(), sizeof(float),
cudaMemcpyHostToDevice);
cudaMemcpy(tensor_.scale(), scale_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice);
}
auto [rowwise_scale_meta, colwise_scale_meta] =
get_scales(s, tensor_.scaling_mode());
if (rowwise_) {
auto scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size;
auto scale_size = rowwise_scale_meta.bytes();
cudaMemcpy(tensor_.get_rowwise_scale_inv().data_ptr,
rowwise_scale_inv_cpu_data_.get(), scale_size,
cudaMemcpyHostToDevice);
}
if (columnwise_) {
auto scale_size = product(colwise_scale_meta.shape) * colwise_scale_meta.type_size;
auto scale_size = colwise_scale_meta.bytes();
cudaMemcpy(tensor_.get_columnwise_scale_inv().data_ptr,
columnwise_scale_inv_cpu_data_.get(), scale_size,
cudaMemcpyHostToDevice);
......
......@@ -10,10 +10,15 @@
#include <vector>
#include <array>
#include <random>
#include <cudaTypedefs.h>
#define FP4_TYPE_SUPPORTED (CUDA_VERSION >= 12080)
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#if FP4_TYPE_SUPPORTED
#include <cuda_fp4.h>
#endif
#include <cuda_runtime_api.h>
#include <transformer_engine/transformer_engine.h>
......@@ -55,19 +60,32 @@ using bf16 = nv_bfloat16;
using fp8e4m3 = __nv_fp8_e4m3;
using fp8e5m2 = __nv_fp8_e5m2;
using fp8e8m0 = uint8_t;
#if FP4_TYPE_SUPPORTED
using fp4e2m1 = __nv_fp4_e2m1;
#endif
template <typename T>
struct TypeInfo{
using types = std::tuple<byte,
int16,
int32,
int64,
fp32,
fp16,
bf16,
fp8e4m3,
fp8e5m2,
fp8e8m0>;
struct BitsNumber;
#if FP4_TYPE_SUPPORTED
template <>
struct BitsNumber<fp4e2m1> {
static constexpr size_t num_bits = 4;
};
#endif
template <typename T>
struct BitsNumber {
static constexpr size_t num_bits = 8 * sizeof(T);
};
template <typename T>
struct TypeInfo {
#if FP4_TYPE_SUPPORTED
using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, fp8e8m0, fp4e2m1>;
#else
using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, fp8e8m0>;
#endif
template <typename U, DType current>
struct Helper {
......@@ -94,7 +112,7 @@ struct TypeInfo{
}
constexpr static DType dtype = getType<T>();
constexpr static size_t size = sizeof(T);
constexpr static size_t size = BitsNumber<T>::num_bits;
};
class Tensor {
......@@ -416,9 +434,10 @@ inline float dsilu(const float x) { return x * dsigmoid(x) + sigmoid(x); }
inline float srelu(const float x) { return x > 0 ? x * x : 0; }
inline float dsrelu(const float x) { return fmaxf(0, 2 * x); }
size_t typeToSize(DType type);
size_t typeToNumBits(DType type);
size_t product(const NVTEShape &shape);
size_t product(const std::vector<size_t> &shape);
size_t bytes(const NVTEShape& shape, const DType type);
size_t first_dimension(const std::vector<size_t> &shape);
size_t last_dimension(const std::vector<size_t> &shape);
......@@ -464,6 +483,16 @@ constexpr int32_t blackwellComputeCapability = 100;
} // namespace test
#if FP4_TYPE_SUPPORTED
#define SWITCH_FP4_TYPE_HANDLE(type, ...) \
case DType::kFloat4E2M1: { \
using type = fp4e2m1; \
{ __VA_ARGS__ } \
} break;
#else
#define SWITCH_FP4_TYPE_HANDLE(type, ...) // do nothing
#endif
#define TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(dtype, type, ...) \
switch (dtype) { \
using namespace transformer_engine; \
......@@ -515,8 +544,16 @@ constexpr int32_t blackwellComputeCapability = 100;
{__VA_ARGS__} \
} \
break; \
case DType::kFloat8E8M0: \
{ \
using type = fp8e8m0; \
{__VA_ARGS__} \
} \
break; \
SWITCH_FP4_TYPE_HANDLE(type, __VA_ARGS__) \
default: \
NVTE_ERROR("Invalid type."); \
printf("dtype: %d\n", static_cast<int>(dtype)); \
NVTE_ERROR("Invalid type MARKED TEST."); \
}
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(dtype, type, ...) \
......@@ -535,7 +572,15 @@ constexpr int32_t blackwellComputeCapability = 100;
} \
break; \
default: \
NVTE_ERROR("Invalid type."); \
NVTE_ERROR("Invalid type MARKED TEST 2."); \
}
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP4_ONLY(dtype, type, ...) \
switch (dtype) { \
using namespace transformer_engine; \
SWITCH_FP4_TYPE_HANDLE(type, __VA_ARGS__) \
default: \
NVTE_ERROR("Invalid type MARKED TEST 3."); \
}
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(dtype, type, ...) \
......@@ -560,5 +605,5 @@ constexpr int32_t blackwellComputeCapability = 100;
} \
break; \
default: \
NVTE_ERROR("Invalid type."); \
NVTE_ERROR("Invalid type MARKED TEST 4."); \
}
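
The test helpers above add a BitsNumber trait so that TypeInfo<T>::size reports a width in bits, specialized for the sub-byte FP4 type. A self-contained sketch of that pattern; the fp4e2m1_stub type is a hypothetical stand-in for __nv_fp4_e2m1, which the real header only aliases when FP4_TYPE_SUPPORTED holds:

```cpp
#include <cstddef>

// Stand-in for __nv_fp4_e2m1; the real code only defines the FP4 alias when
// FP4_TYPE_SUPPORTED (CUDA >= 12.8) and <cuda_fp4.h> are available.
struct fp4e2m1_stub {};

// Generic case: width in bits follows directly from sizeof().
template <typename T>
struct BitsNumber {
  static constexpr std::size_t num_bits = 8 * sizeof(T);
};

// FP4 is the one sub-byte type, so it gets an explicit specialization.
template <>
struct BitsNumber<fp4e2m1_stub> {
  static constexpr std::size_t num_bits = 4;
};

static_assert(BitsNumber<float>::num_bits == 32, "fp32 is 32 bits");
static_assert(BitsNumber<fp4e2m1_stub>::num_bits == 4, "fp4 is 4 bits");
```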
......@@ -196,7 +196,7 @@ TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, siz
if (param_type == NVTETensorParam::kNVTERowwiseData ||
param_type == NVTETensorParam::kNVTEColumnwiseData) {
// Offset data pointer
param_dptr += chunk_offset * typeToSize(param_dtype);
param_dptr += get_buffer_size_bytes(chunk_offset, param_dtype);
param_shape = chunk_shape;
if (param_type == NVTETensorParam::kNVTEColumnwiseData &&
......@@ -217,7 +217,7 @@ TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, siz
} else {
chunk_scale_height /= 32;
}
param_dptr += (chunk_offset / 32) * typeToSize(param_dtype);
param_dptr += get_buffer_size_bytes(chunk_offset / 32, param_dtype);
param_shape = {chunk_scale_height, chunk_scale_width};
}
......@@ -236,7 +236,7 @@ TensorWrapper CommOverlapCore::get_buffer_chunk_like(const TensorWrapper &source
auto chunk = get_tensor_chunk(source, chunk_offset, chunk_shape);
// Update chunk with offset data pointers from the communication buffer
auto ubuf_ptr = reinterpret_cast<char *>(_ubuf.dptr()) + (chunk_offset * _ubuf.element_size());
auto ubuf_ptr = reinterpret_cast<char *>(_ubuf.dptr()) + chunk_offset * _ubuf.element_size();
if (chunk.dptr() != nullptr) {
chunk.set_rowwise_data(reinterpret_cast<void *>(ubuf_ptr), chunk.dtype(), chunk.shape());
}
......@@ -269,7 +269,7 @@ CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType
"or 2 (multi-atomic).");
NVTE_CHECK(buffer_shape.size() == 2, "Userbuffer shape must be 2-dimensional!");
size_t buffer_bytes = buffer_shape[0] * buffer_shape[1] * typeToSize(buffer_dtype);
size_t buffer_bytes = get_buffer_size_bytes(buffer_shape[0], buffer_shape[1], buffer_dtype);
void *buffer_ptr;
_ub_reg = register_user_buffer_collective(&buffer_ptr, buffer_bytes, _ub_comm, true);
if (_ub_comm->myrank == 0) printf("!!! [UB] Register UBuf %d\n", _ub_reg);
......@@ -306,7 +306,7 @@ void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const Te
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[0], _start_comm, 0));
// Communication: AG and RS
int comm_elements = (_ubuf.numel() / 2) * _ubuf.element_size(); // UBUF uses 2Byte element size
int comm_elements = _ubuf.bytes() / 2; // UBUF uses 2Byte element size
if (comm_type == CommOverlapType::AG) {
allgather2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm,
(cudaEvent_t)_comm_launch_event);
......@@ -606,7 +606,7 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape,
// Create workspace tensor with userbuffer
NVTE_CHECK(buffer_shape.size() == 2, "Userbuffer shape must be 2-dimensional!");
size_t buffer_bytes = buffer_shape[0] * buffer_shape[1] * typeToSize(buffer_dtype);
size_t buffer_bytes = get_buffer_size_bytes(buffer_shape[0], buffer_shape[1], buffer_dtype);
int buffer_chunk_bytes = buffer_bytes / tp_size;
_num_ubuf_chunks = tp_size;
if (_is_reduce_scatter) {
......@@ -704,7 +704,7 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag(
assert(pre_gelu_out.numel() == 0);
// Get communication and GEMM output chunk sizes
const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();
const int comm_bytes = _ubufs[0].bytes();
// Create an GEMM output buffer with N+1 chunks in a contiguous memory
void *D_buffer_ptr;
......@@ -762,21 +762,20 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag(
if (B_copy.numel() > 0) {
assert(B_copy.numel() == _ubufs[_self_chunk_id].numel());
assert(B_copy.element_size() == _ubufs[_self_chunk_id].element_size());
NVTE_CHECK_CUDA(
cudaMemcpyAsync(B_copy.dptr(), _ubufs[_self_chunk_id].dptr(),
_ubufs[_self_chunk_id].numel() * _ubufs[_self_chunk_id].element_size(),
cudaMemcpyDeviceToDevice, _stream_send[0]));
NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_self_chunk_id].dptr(),
_ubufs[_self_chunk_id].bytes(), cudaMemcpyDeviceToDevice,
_stream_send[0]));
NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send[0]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0));
}
// Copy the first GEMM output chunk to the end chunk position of D_buffer
char *src_ptr = reinterpret_cast<char *>(D_buffer.dptr());
NVTE_CHECK_CUDA(cudaMemcpyAsync(src_ptr + (D.numel() * D.element_size()), src_ptr, D_chunk_bytes,
NVTE_CHECK_CUDA(cudaMemcpyAsync(src_ptr + D.bytes(), src_ptr, D_chunk_bytes,
cudaMemcpyDeviceToDevice, stream_main));
// Return the last N rows of D_buffer
NVTE_CHECK_CUDA(cudaMemcpyAsync(D.dptr(), src_ptr + D_chunk_bytes, D.numel() * D.element_size(),
NVTE_CHECK_CUDA(cudaMemcpyAsync(D.dptr(), src_ptr + D_chunk_bytes, D.bytes(),
cudaMemcpyDeviceToDevice, stream_main));
// Clean up buffer allocation
......@@ -806,7 +805,7 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
const size_t n_chunk = _ubufs[0].size(0);
// Get communication and GEMM output chunk sizes
const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();
const int comm_bytes = _ubufs[0].bytes();
const bool do_gelu = pre_gelu_out.numel() > 0;
size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
......@@ -882,8 +881,8 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
assert(B_copy.numel() == _ubufs[_tp_id].numel());
assert(B_copy.element_size() == _ubufs[_tp_id].element_size());
NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(),
_ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(),
cudaMemcpyDeviceToDevice, _stream_send[0]));
_ubufs[_tp_id].bytes(), cudaMemcpyDeviceToDevice,
_stream_send[0]));
}
}
} else {
......@@ -935,8 +934,8 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
assert(B_copy.numel() == _ubufs[_tp_id].numel());
assert(B_copy.element_size() == _ubufs[_tp_id].element_size());
NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(),
_ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(),
cudaMemcpyDeviceToDevice, _stream_send[0]));
_ubufs[_tp_id].bytes(), cudaMemcpyDeviceToDevice,
_stream_send[0]));
}
}
}
......@@ -966,7 +965,7 @@ void CommOverlapP2PBase::atomic_gemm_overlap_rs(
_ub_comm->cga_size = _cga_size;
// Get communication and GEMM input chunk sizes
const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();
const int comm_bytes = _ubufs[0].bytes();
// Reset counters
int *counter_ptr = reinterpret_cast<int *>(_counter.dptr());
......@@ -1033,7 +1032,7 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa,
size_t m = transa ? A.size(0) : A.size(1);
size_t k = transa ? A.size(1) : A.size(0);
size_t n_chunk = _ubufs[0].size(0);
const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();
const int comm_bytes = _ubufs[0].bytes();
// Get input and workspace data pointers
size_t input_chunk_size = n_chunk * k;
......
......@@ -116,13 +116,20 @@ void checkCuDriverContext(CUstream stream) {
}
CUtensorMapDataType get_CUtensorMapDataType(DType dtype) {
static const std::unordered_map<DType, CUtensorMapDataType> dtypeMapping = {
{DType::kByte, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8},
{DType::kFloat32, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT32},
{DType::kFloat16, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16},
{DType::kBFloat16, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16},
{DType::kFloat8E4M3, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8},
{DType::kFloat8E5M2, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8}};
static const std::unordered_map<DType, CUtensorMapDataType> dtypeMapping = []() {
std::unordered_map<DType, CUtensorMapDataType> typeMapping = {
{DType::kByte, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8},
{DType::kFloat32, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT32},
{DType::kFloat16, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16},
{DType::kBFloat16, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16},
{DType::kFloat8E4M3, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8},
{DType::kFloat8E5M2, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8}};
#if FP4_TYPE_SUPPORTED
typeMapping.insert(
{DType::kFloat4E2M1, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN8B});
#endif
return typeMapping;
}();
return dtypeMapping.at(dtype);
}
......@@ -130,7 +137,7 @@ CUtensorMapDataType get_CUtensorMapDataType(DType dtype) {
void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY,
const uint32_t shmemX, const uint32_t stride_elems,
const uint32_t offset_elems, const size_t type_size) {
const uint32_t offset_elems, const size_t type_num_bits) {
// Get a function pointer to the cuTensorMapEncodeTiled driver API
// Note: PFN_cuTensorMapEncodeTiled is not defined in cuda13
static PFN_cuTensorMapEncodeTiled_v12000 cuDriverTensorMapEncodeTiled = []() {
......@@ -142,7 +149,7 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
uint64_t size[rank] = {globalX, globalY};
// The stride is the number of bytes to traverse from the first element of one row to the next
uint64_t stride[rank - 1] = {stride_elems * type_size};
uint64_t stride[rank - 1] = {(stride_elems * type_num_bits) / 8};
// The boxSize is the size of the shared memory buffer that is used as the
// source/destination of a TMA transfer
......@@ -152,15 +159,15 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
uint32_t elemStride[rank] = {1, 1};
const CUtensorMapDataType tensorDataType = get_CUtensorMapDataType(tensor.dtype);
void *dataPtr =
reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(tensor.dptr) + offset_elems * type_size);
void *dataPtr = reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(tensor.dptr) +
(offset_elems * type_num_bits) / 8);
NVTE_CHECK(is_aligned_ptr(dataPtr, TMA_gmem_alignment),
"Tensor data pointer must be 16B aligned");
const int TMA_needed_size = TMA_gmem_alignment / type_size;
NVTE_CHECK(globalX % TMA_needed_size == 0, "Shape not supported. For ", type_size,
"-byte data type, expected multiple of ", TMA_needed_size, ", got ", globalX);
const int TMA_needed_size = (TMA_gmem_alignment * 8) / type_num_bits;
NVTE_CHECK(globalX % TMA_needed_size == 0, "Shape not supported. For ", type_num_bits,
"-bit data type, expected multiple of ", TMA_needed_size, ", got ", globalX);
// Create the tensor descriptor.
NVTE_CHECK_CUDA_DRIVER(cuDriverTensorMapEncodeTiled(
......@@ -206,4 +213,18 @@ std::vector<std::vector<Tensor *>> convert_tensor_array(NVTETensor **nvte_tensor
return ret;
}
size_t get_buffer_size_bytes(const size_t elements_num, const DType buffer_dtype) {
return (elements_num * typeToNumBits(buffer_dtype)) / 8;
}
size_t get_buffer_size_bytes(const size_t dim_first, const size_t dim_last,
const DType buffer_dtype) {
if (buffer_dtype == DType::kFloat4E2M1) {
NVTE_CHECK(dim_last % 2 == 0,
"Last dimension of a tensor with FP4 type of data must be an even number!");
}
const size_t elements_num = dim_first * dim_last;
return get_buffer_size_bytes(elements_num, buffer_dtype);
}
} // namespace transformer_engine
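
A hedged usage sketch of the get_buffer_size_bytes overloads defined above, with illustrative dimensions; the include path is an assumption:

```cpp
#include <cassert>

#include "common.h"  // assumed include exposing get_buffer_size_bytes and DType

void buffer_size_example() {
  using namespace transformer_engine;
  // 1024 x 512 FP4 tensor: 512 elements * 4 bits = 256 bytes per row.
  const size_t fp4_bytes = get_buffer_size_bytes(1024, 512, DType::kFloat4E2M1);
  assert(fp4_bytes == 1024 * 256);
  // An odd innermost dimension (e.g. 513) would trip the NVTE_CHECK above,
  // since FP4 rows must pack into whole bytes. Other dtypes are unrestricted:
  const size_t bf16_bytes = get_buffer_size_bytes(1024, 513, DType::kBFloat16);
  assert(bf16_bytes == 1024 * 513 * 2);
}
```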
......@@ -8,9 +8,15 @@
#define TRANSFORMER_ENGINE_COMMON_COMMON_H_
#include <cudaTypedefs.h>
#define FP4_TYPE_SUPPORTED (CUDA_VERSION >= 12080)
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#if FP4_TYPE_SUPPORTED
#include <cuda_fp4.h>
#endif
#include <cuda_runtime_api.h>
#include <transformer_engine/transformer_engine.h>
......@@ -183,6 +189,7 @@ struct Tensor {
}
break;
case NVTE_MXFP8_1D_SCALING:
case NVTE_FWD_NVFP4_BWD_MXFP8_SCALING:
if (!has_data() && has_columnwise_data()) {
return columnwise_data.shape;
} else {
......@@ -268,6 +275,13 @@ constexpr T DIVUP(const T &x, const T &y) {
return (((x) + ((y)-1)) / (y));
}
template <typename T1, typename T2>
constexpr __device__ __host__ __forceinline__ uint64_t DIVUP_TO_MULTIPLE(const T1 &N, const T2 &M) {
static_assert(std::is_integral<T1>::value && std::is_integral<T2>::value,
"Integral type required.");
return DIVUP(static_cast<uint64_t>(N), static_cast<uint64_t>(M)) * M;
}
using byte = uint8_t;
using int16 = int16_t;
using int32 = int32_t;
......@@ -280,6 +294,9 @@ using fp8e5m2 = __nv_fp8_e5m2;
#if CUDA_VERSION >= 12080
using fp8e8m0 = __nv_fp8_e8m0;
#endif
#if FP4_TYPE_SUPPORTED
using fp4e2m1 = __nv_fp4_e2m1;
#endif
using e8m0_t = uint8_t;
namespace detail {
......@@ -303,11 +320,21 @@ TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e5m2)
#if CUDA_VERSION >= 12080
TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e8m0)
#endif
#if FP4_TYPE_SUPPORTED
TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp4_e2m1)
#endif
#undef TRANSFORMER_ENGINE_TYPE_NAME
template <typename T>
struct TypeExtrema;
#if FP4_TYPE_SUPPORTED
template <>
struct TypeExtrema<fp4e2m1> {
static constexpr float max = 6.0f;
};
#endif
template <>
struct TypeExtrema<fp8e4m3> {
static constexpr float max = 448.0f;
......@@ -337,9 +364,28 @@ struct TypeExtrema {
} // namespace detail
template <typename T>
struct BitsNumber;
#if FP4_TYPE_SUPPORTED
template <>
struct BitsNumber<fp4e2m1> {
static constexpr size_t num_bits = 4;
};
#endif
template <typename T>
struct BitsNumber {
static constexpr size_t num_bits = 8 * sizeof(T);
};
template <typename T>
struct TypeInfo {
#if FP4_TYPE_SUPPORTED
using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, fp4e2m1>;
#else
using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2>;
#endif
template <typename U, DType current>
struct Helper {
......@@ -364,11 +410,21 @@ struct TypeInfo {
}
constexpr static DType dtype = getType<T>();
constexpr static size_t size = sizeof(T);
constexpr static size_t size = BitsNumber<T>::num_bits;
constexpr static float max_finite_value = detail::TypeExtrema<T>::max;
constexpr static const char *name = detail::type_name<T>();
};
#if FP4_TYPE_SUPPORTED
#define SWITCH_FP4_TYPE_HANDLE(type, ...) \
case DType::kFloat4E2M1: { \
using type = fp4e2m1; \
{ __VA_ARGS__ } \
} break;
#else
#define SWITCH_FP4_TYPE_HANDLE(type, ...) // do nothing
#endif
#define TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(dtype, type, ...) \
switch (dtype) { \
using namespace transformer_engine; \
......@@ -412,6 +468,7 @@ struct TypeInfo {
using type = byte; \
{ __VA_ARGS__ } \
} break; \
SWITCH_FP4_TYPE_HANDLE(type, __VA_ARGS__) \
default: \
NVTE_ERROR("Invalid type."); \
}
......@@ -523,6 +580,9 @@ struct TypeInfo {
case DType::kFloat8E4M3: { \
NVTE_ERROR("FP8 type not instantiated for input."); \
} break; \
case DType::kFloat4E2M1: { \
NVTE_ERROR("FP4 type not instantiated for input."); \
} break; \
default: \
NVTE_ERROR("Invalid type."); \
}
......@@ -593,6 +653,14 @@ struct is_fp8<fp8e4m3> : std::true_type {};
template <>
struct is_fp8<fp8e5m2> : std::true_type {};
template <typename T>
struct is_fp4 : std::false_type {};
#if FP4_TYPE_SUPPORTED
template <>
struct is_fp4<fp4e2m1> : std::true_type {};
#endif
// [128,4] rowwise and [4,128] colwise alignment requirements for the tensor with scaling factors
constexpr size_t scale_tensor_alignment_X_rowwise = 4;
constexpr size_t scale_tensor_alignment_Y_rowwise = 128;
......@@ -611,13 +679,16 @@ inline bool is_aligned_tensor_data(const Tensor &t, size_t alignment) {
}
size_t typeToSize(const DType type);
size_t typeToNumBits(const DType type);
size_t get_buffer_size_bytes(const size_t N, const DType buffer_dtype);
size_t get_buffer_size_bytes(const size_t dim_first, const size_t dim_last,
const DType buffer_dtype);
void CheckNoopTensor(const Tensor &t, const std::string &name);
void CheckInputTensor(const Tensor &t, const std::string &name);
void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty = false);
bool is_fp8_dtype(const DType t);
/*! \brief Update a tensor's FP8 scale-inverse
*
* The FP8 scale-inverse (dequantization scaling factor) is updated
......@@ -636,7 +707,7 @@ CUtensorMapDataType get_CUtensorMapDataType(DType dtype);
void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY,
const uint32_t shmemX, const uint32_t stride_elems,
const uint32_t offset_elems, const size_t type_size);
const uint32_t offset_elems, const size_t type_num_bits);
bool is_supported_by_CC_100();
......
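
The common header also threads FP4 into the type-dispatch macros: SWITCH_FP4_TYPE_HANDLE expands to a case label when FP4 is available and to nothing otherwise, so existing switches still compile on older CUDA toolkits. A simplified, standalone sketch of that conditional-case pattern (the MY_-prefixed names are illustrative, not the library's macros):

```cpp
#include <cstdio>

enum class DType { kFloat32, kFloat4E2M1 };

#define MY_FP4_SUPPORTED 1

#if MY_FP4_SUPPORTED
#define MY_SWITCH_FP4_HANDLE(...) \
  case DType::kFloat4E2M1: {      \
    __VA_ARGS__                   \
  } break;
#else
#define MY_SWITCH_FP4_HANDLE(...)  // expands to nothing
#endif

#define MY_TYPE_SWITCH(dtype, ...)    \
  switch (dtype) {                    \
    case DType::kFloat32: {           \
      __VA_ARGS__                     \
    } break;                          \
    MY_SWITCH_FP4_HANDLE(__VA_ARGS__) \
    default:                          \
      break;                          \
  }

int main() {
  const DType t = DType::kFloat4E2M1;
  MY_TYPE_SWITCH(t, std::printf("dispatched dtype %d\n", static_cast<int>(t));)
  return 0;
}
```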
......@@ -325,7 +325,7 @@ void thd_read_half_tensor(const Tensor &tensor, const Tensor &cu_seqlens, Tensor
int batch = cu_seqlens_shape[0] - 1;
int num_heads = tensor_shape[seq_dim + 1];
int dim_per_head = tensor_shape[seq_dim + 2];
int hidden_size_in_bytes = num_heads * dim_per_head * typeToSize(tensor.dtype());
int hidden_size_in_bytes = (num_heads * dim_per_head * typeToNumBits(tensor.dtype())) / 8;
// For 128-bits load/store
NVTE_CHECK(hidden_size_in_bytes % 16 == 0);
......@@ -582,7 +582,7 @@ static void thd_grad_correction_helper(Tensor grad, const Tensor &grad_per_step,
NVTE_CHECK(grad_per_step_shape[seq_dim + 2] == dim_per_head);
size_t hidden_size = num_heads * dim_per_head;
NVTE_CHECK((hidden_size * typeToSize(grad.dtype())) % 16 == 0);
NVTE_CHECK(((hidden_size * typeToNumBits(grad.dtype())) / 8) % 16 == 0);
constexpr unsigned int block = 256;
unsigned int grid_x;
......
......@@ -377,7 +377,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
const size_t num_bytes_per_seqlen = alignTo<16>(b * sizeof(int32_t));
const size_t actual_seqlen_workspace_size = is_padding ? 2 * num_bytes_per_seqlen : 0;
const size_t num_bytes_per_ragged_offset =
alignTo<16>((b + 1) * typeToSize(ragged_offset_type));
alignTo<16>(((b + 1) * typeToNumBits(ragged_offset_type)) / 8);
size_t seqlen_offsets_workspace_size = 0;
if (is_ragged_q || is_ragged_kv) {
size_t count = 2 * (static_cast<size_t>(is_ragged_q) + static_cast<size_t>(is_ragged_kv));
......@@ -831,7 +831,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
const size_t num_bytes_per_seqlen = alignTo<16>(b * sizeof(int32_t));
const size_t actual_seqlen_workspace_size = is_padding ? 2 * num_bytes_per_seqlen : 0;
const size_t num_bytes_per_ragged_offset =
alignTo<16>((b + 1) * typeToSize(ragged_offset_type));
alignTo<16>(((b + 1) * typeToNumBits(ragged_offset_type)) / 8);
size_t seqlen_offsets_workspace_size = 0;
if (is_ragged_q || is_ragged_kv) {
size_t count = 2 * (static_cast<size_t>(is_ragged_q) + static_cast<size_t>(is_ragged_kv));
......@@ -957,9 +957,9 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
size_t stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
stride = typeToSize(QKV_type) * num_attn_heads * head_dim;
stride = (typeToNumBits(QKV_type) * num_attn_heads * head_dim) / 8;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) {
stride = typeToSize(QKV_type) * head_dim;
stride = (typeToNumBits(QKV_type) * head_dim) / 8;
}
void *devPtrQ = static_cast<void *>(devPtrQKV);
void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride);
......@@ -1082,9 +1082,9 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
size_t stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
stride = typeToSize(QKV_type) * num_attn_heads * head_dim;
stride = (typeToNumBits(QKV_type) * num_attn_heads * head_dim) / 8;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) {
stride = typeToSize(QKV_type) * head_dim;
stride = (typeToNumBits(QKV_type) * head_dim) / 8;
}
void *devPtrQ = devPtrQKV;
void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride);
......@@ -1173,9 +1173,9 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout);
size_t stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
stride = typeToSize(QKV_type) * num_gqa_groups * head_dim;
stride = (typeToNumBits(QKV_type) * num_gqa_groups * head_dim) / 8;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) {
stride = typeToSize(QKV_type) * head_dim;
stride = (typeToNumBits(QKV_type) * head_dim) / 8;
}
void *devPtrK = devPtrKV;
void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrKV) + stride);
......@@ -1313,9 +1313,9 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
size_t stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
stride = typeToSize(QKV_type) * num_gqa_groups * head_dim;
stride = (typeToNumBits(QKV_type) * num_gqa_groups * head_dim) / 8;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) {
stride = typeToSize(QKV_type) * head_dim;
stride = (typeToNumBits(QKV_type) * head_dim) / 8;
}
void *devPtrK = devPtrKV;
void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrKV) + stride);
......
......@@ -2364,9 +2364,9 @@ void fused_attn_fp8_fwd_qkvpacked(size_t batch, size_t num_attn_heads, size_t ma
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
size_t stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
stride = typeToSize(QKV_type) * num_attn_heads * head_dim;
stride = (typeToNumBits(QKV_type) * num_attn_heads * head_dim) / 8;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) {
stride = typeToSize(QKV_type) * head_dim;
stride = (typeToNumBits(QKV_type) * head_dim) / 8;
}
void* devPtrQ = static_cast<void*>(devPtrQKV);
void* devPtrK = static_cast<void*>(static_cast<int8_t*>(devPtrQKV) + stride);
......@@ -2466,9 +2466,9 @@ void fused_attn_fp8_bwd_qkvpacked(
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
size_t stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
stride = typeToSize(QKV_type) * num_attn_heads * head_dim;
stride = (typeToNumBits(QKV_type) * num_attn_heads * head_dim) / 8;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) {
stride = typeToSize(QKV_type) * head_dim;
stride = (typeToNumBits(QKV_type) * head_dim) / 8;
}
void* devPtrQ = devPtrQKV;
void* devPtrK = static_cast<void*>(static_cast<int8_t*>(devPtrQKV) + stride);
......@@ -2564,9 +2564,9 @@ void fused_attn_fp8_fwd_kvpacked(size_t batch, size_t num_attn_heads, size_t num
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
size_t stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
stride = typeToSize(QKV_type) * num_gqa_groups * head_dim;
stride = (typeToNumBits(QKV_type) * num_gqa_groups * head_dim) / 8;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) {
stride = typeToSize(QKV_type) * head_dim;
stride = (typeToNumBits(QKV_type) * head_dim) / 8;
}
void* devPtrK = devPtrKV;
void* devPtrV = static_cast<void*>(static_cast<int8_t*>(devPtrKV) + stride);
......@@ -2671,9 +2671,9 @@ void fused_attn_fp8_bwd_kvpacked(
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
size_t stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
stride = typeToSize(QKV_type) * num_gqa_groups * head_dim;
stride = (typeToNumBits(QKV_type) * num_gqa_groups * head_dim) / 8;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) {
stride = typeToSize(QKV_type) * head_dim;
stride = (typeToNumBits(QKV_type) * head_dim) / 8;
}
void* devPtrK = devPtrKV;
void* devPtrV = static_cast<void*>(static_cast<int8_t*>(devPtrKV) + stride);
......
......@@ -22,17 +22,18 @@ extern "C" {
* \brief TE datatype.
*/
enum NVTEDType {
kNVTEByte = 0, /*!< Byte */
kNVTEInt16 = 1, /*!< 16-bit integer */
kNVTEInt32 = 2, /*!< 32-bit integer */
kNVTEInt64 = 3, /*!< 64-bit integer */
kNVTEFloat32 = 4, /*!< 32-bit float */
kNVTEFloat16 = 5, /*!< 16-bit float (E5M10) */
kNVTEBFloat16 = 6, /*!< 16-bit bfloat (E8M7) */
kNVTEFloat8E4M3 = 7, /*!< 8-bit float (E4M3) */
kNVTEFloat8E5M2 = 8, /*!< 8-bit float (E5M2) */
kNVTEFloat8E8M0 = 9, /*!< 8-bit float (E8M0) */
kNVTENumTypes /*!< Number of supported types */
kNVTEByte = 0, /*!< Byte */
kNVTEInt16 = 1, /*!< 16-bit integer */
kNVTEInt32 = 2, /*!< 32-bit integer */
kNVTEInt64 = 3, /*!< 64-bit integer */
kNVTEFloat32 = 4, /*!< 32-bit float */
kNVTEFloat16 = 5, /*!< 16-bit float (E5M10) */
kNVTEBFloat16 = 6, /*!< 16-bit bfloat (E8M7) */
kNVTEFloat8E4M3 = 7, /*!< 8-bit float (E4M3) */
kNVTEFloat8E5M2 = 8, /*!< 8-bit float (E5M2) */
kNVTEFloat8E8M0 = 9, /*!< 8-bit float (E8M0) */
kNVTEFloat4E2M1 = 10, /*!< 4-bit float (E2M1) */
kNVTENumTypes /*!< Number of supported types */
};
/*! \struct NVTEShape
......@@ -87,6 +88,10 @@ enum NVTEScalingMode {
*/
NVTE_BLOCK_SCALING_1D = 2,
NVTE_BLOCK_SCALING_2D = 3,
/*! Single NVFP4 scale per block of 16 contiguous elements in forward pass (FWD),
and single MXFP8 scale per block of 32 contiguous elements in backward pass (BWD).
*/
NVTE_FWD_NVFP4_BWD_MXFP8_SCALING = 4,
NVTE_INVALID_SCALING = 100
};
......@@ -177,6 +182,14 @@ size_t nvte_tensor_ndims(const NVTETensor tensor);
*/
size_t nvte_tensor_size(const NVTETensor tensor, const size_t dim);
/*! \brief Get the byte size for the tensor.
*
* \param[in] tensor Tensor.
*
* \return Byte size of the tensor.
*/
size_t nvte_tensor_size_bytes(const NVTETensor tensor);
/*! \brief Get a tensor's total number of elements.
*
* \param[in] tensor Tensor.
......@@ -193,6 +206,14 @@ size_t nvte_tensor_numel(const NVTETensor tensor);
*/
size_t nvte_tensor_element_size(const NVTETensor tensor);
/*! \brief Get the bit size for the tensor's data type.
*
* \param[in] tensor Tensor.
*
* \return Bit size of the tensor's data type.
*/
size_t nvte_tensor_element_size_bits(const NVTETensor tensor);
/*! \brief Get a tensor's data type.
*
* \param[in] tensor Tensor.
......@@ -390,6 +411,7 @@ enum class DType {
kFloat8E4M3 = 7,
kFloat8E5M2 = 8,
kFloat8E8M0 = 9,
kFloat4E2M1 = 10,
kNumTypes
};
......@@ -398,7 +420,16 @@ enum class DType {
* Return true if TE datatype is FP8
* \param[in] DType TE Datatype of interest
*/
bool is_fp8_dtype(const DType t);
inline bool is_fp8_dtype(const DType t) {
return t == DType::kFloat8E4M3 || t == DType::kFloat8E5M2;
}
/*! \brief Check if TE datatype is FP4
*
* Return true if TE datatype is FP4
* \param[in] DType TE Datatype of interest
*/
inline bool is_fp4_dtype(const DType t) { return t == DType::kFloat4E2M1; }
/*! \struct TensorWrapper
* \brief C++ wrapper for the NVTETensor class.
......@@ -627,6 +658,15 @@ class TensorWrapper {
return nvte_tensor_element_size(tensor_);
}
/*! \brief Get the tensor's element size in bits.
*
* \return Element size in bits.
*/
size_t element_size_bits() const noexcept {
if (tensor_ == nullptr) return 0;
return nvte_tensor_element_size_bits(tensor_);
}
/*! \brief Get the tensor's allocated size in bytes. This will return 0 for tensors with nullptr
* data even if the TensorWrapper has a non-zero shape and valid dtype.
*
......@@ -634,7 +674,7 @@ class TensorWrapper {
*/
size_t bytes() const noexcept {
if (tensor_ == nullptr || this->dptr() == nullptr) return 0;
return nvte_tensor_numel(tensor_) * nvte_tensor_element_size(tensor_);
return nvte_tensor_size_bytes(tensor_);
}
/*! \brief Get the data type of this TensorWrapper.
......
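
A short usage sketch of the new size queries declared above; it assumes a valid NVTETensor created elsewhere:

```cpp
#include <cstdio>

#include <transformer_engine/transformer_engine.h>

// For FP4 tensors, nvte_tensor_element_size() is no longer meaningful (a
// single element does not occupy a whole byte), so the bit-granular queries
// are used instead.
void print_sizes(const NVTETensor t) {
  const size_t element_bits = nvte_tensor_element_size_bits(t);  // 4 for FP4
  const size_t total_bytes = nvte_tensor_size_bytes(t);          // numel * bits / 8
  std::printf("element: %zu bits, tensor: %zu bytes\n", element_bits, total_bytes);
}
```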
......@@ -212,8 +212,11 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor
}
const auto gamma_dtype = use_zero_centered_gamma_in_weight_dtype() ? wtype : ctype;
NVTE_CHECK(gamma_dtype == DType::kFloat32 || gamma_dtype == DType::kFloat16 ||
gamma_dtype == DType::kBFloat16,
"Gamma of type FP4 is not supported");
_scalar_dptr = std::make_unique<char[]>(typeToSize(gamma_dtype));
_scalar_dptr = std::make_unique<char[]>(typeToNumBits(gamma_dtype) / 8);
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
gamma_dtype, cpp_dtype,
*(reinterpret_cast<cpp_dtype*>(_scalar_dptr.get())) = (cpp_dtype)1.0f;);
......
......@@ -18,12 +18,15 @@
namespace transformer_engine {
size_t typeToSize(const DType type) {
size_t typeToNumBits(const DType type) {
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(type, T,
return TypeInfo<T>::size;); // NOLINT(*)
}
bool is_fp8_dtype(const DType t) { return t == DType::kFloat8E4M3 || t == DType::kFloat8E5M2; }
size_t typeToSize(const DType type) {
NVTE_CHECK(type != DType::kFloat4E2M1, "typeToSize() Does not support FP4 data type.");
return typeToNumBits(type) / 8;
}
std::string to_string(const DType type) {
switch (type) {
......@@ -41,6 +44,8 @@ std::string to_string(const DType type) {
return "Float8E5M2";
case DType::kFloat8E8M0:
return "Float8E8M0";
case DType::kFloat4E2M1:
return "Float4E2M1";
case DType::kInt32:
return "Int32";
case DType::kInt64:
......@@ -56,6 +61,8 @@ std::string to_string(const NVTEScalingMode &mode) {
return "NVTE_DELAYED_TENSOR_SCALING";
case NVTE_MXFP8_1D_SCALING:
return "NVTE_MXFP8_1D_SCALING";
case NVTE_FWD_NVFP4_BWD_MXFP8_SCALING:
return "NVTE_FWD_NVFP4_BWD_MXFP8_SCALING";
case NVTE_INVALID_SCALING:
return "NVTE_INVALID_SCALING";
}
......@@ -85,10 +92,13 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) {
t.columnwise_scale_inv.shape, ")");
}
} else {
if (t.scaling_mode == NVTE_MXFP8_1D_SCALING) {
if (t.scaling_mode == NVTE_MXFP8_1D_SCALING ||
t.scaling_mode == NVTE_FWD_NVFP4_BWD_MXFP8_SCALING) {
// Need (4, 128) alignment even for e8 scaling factor
auto block_alignment = std::vector<size_t>{128ul, 4ul};
size_t expected_x, expected_y, alignment;
const size_t block_size_rowwise = (t.scaling_mode == NVTE_MXFP8_1D_SCALING) ? 32 : 16;
const size_t block_size_colwise = 32;
if (t.has_data()) {
alignment = block_alignment[0];
......@@ -96,7 +106,8 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) {
DIVUP(DIVUP(t.flat_first_dim(), static_cast<size_t>(1)), alignment) * alignment;
alignment = block_alignment[1];
expected_y =
DIVUP(DIVUP(t.flat_last_dim(), static_cast<size_t>(32)), alignment) * alignment;
DIVUP(DIVUP(t.flat_last_dim(), static_cast<size_t>(block_size_rowwise)), alignment) *
alignment;
const auto &expected = std::vector<size_t>{expected_x, expected_y};
NVTE_CHECK(t.scale_inv.shape == expected, "Tensor \"", name,
"\" has invalid scale_inv shape (expected ", expected, ", got ",
......@@ -105,7 +116,8 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) {
if (t.has_columnwise_data()) {
alignment = block_alignment[1];
expected_x =
DIVUP(DIVUP(t.flat_first_dim(), static_cast<size_t>(32)), alignment) * alignment;
DIVUP(DIVUP(t.flat_first_dim(), static_cast<size_t>(block_size_colwise)), alignment) *
alignment;
alignment = block_alignment[0];
expected_y = DIVUP(DIVUP(t.flat_last_dim(), static_cast<size_t>(1)), alignment) * alignment;
const auto &expected = std::vector<size_t>{expected_x, expected_y};
......@@ -384,10 +396,24 @@ size_t nvte_tensor_numel(const NVTETensor tensor) {
return numel;
}
size_t nvte_tensor_element_size_bits(const NVTETensor tensor) {
auto *t = transformer_engine::convertNVTETensor(tensor);
if (t == nullptr) return 8 * sizeof(float);
return transformer_engine::typeToNumBits(t->dtype());
}
size_t nvte_tensor_element_size(const NVTETensor tensor) {
auto *t = transformer_engine::convertNVTETensor(tensor);
if (t == nullptr) return sizeof(float);
return transformer_engine::typeToSize(t->dtype());
NVTE_CHECK(!is_fp4_dtype(t->dtype()),
"For FP4 type please use the nvte_tensor_element_size_bits.");
return nvte_tensor_element_size_bits(tensor) / 8;
}
size_t nvte_tensor_size_bytes(const NVTETensor tensor) {
auto *t = transformer_engine::convertNVTETensor(tensor);
if (t == nullptr) return 0;
return (nvte_tensor_numel(tensor) * nvte_tensor_element_size_bits(tensor)) / 8;
}
void *nvte_tensor_data(const NVTETensor tensor) {
......@@ -514,7 +540,7 @@ void nvte_zero_tensor(const NVTETensor tensor, cudaStream_t stream) {
const auto &t = *transformer_engine::convertNVTETensorCheck(tensor);
// Zero out tensor data if allocated
if (t.data.dptr != nullptr) {
size_t size_in_bytes = nvte_tensor_element_size(tensor) * nvte_tensor_numel(tensor);
const size_t size_in_bytes = nvte_tensor_size_bytes(tensor);
cudaMemsetAsync(t.data.dptr, 0, size_in_bytes, stream);
}
// Set amax to 0 if allocated
......
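
Under the new NVTE_FWD_NVFP4_BWD_MXFP8_SCALING mode, the rowwise scale-inverse check above uses a block size of 16 (one NVFP4 scale per 16 contiguous elements) while the columnwise check keeps 32. A worked example of the expected padded scale_inv shape, with illustrative tensor dimensions:

```cpp
#include <cassert>
#include <cstddef>

constexpr std::size_t DIVUP(std::size_t x, std::size_t y) { return (x + y - 1) / y; }

int main() {
  // Illustrative shape; the real check runs on flat_first_dim()/flat_last_dim().
  const std::size_t rows = 300, cols = 1000;
  // NVFP4 rowwise scaling: one scale per 16 elements, padded to the same
  // (128, 4) alignment grid used for MXFP8.
  const std::size_t block_size_rowwise = 16;
  const std::size_t expected_x = DIVUP(rows, 128) * 128;                          // 384
  const std::size_t expected_y = DIVUP(DIVUP(cols, block_size_rowwise), 4) * 4;   // 64
  assert(expected_x == 384 && expected_y == 64);
  return 0;
}
```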
......@@ -192,17 +192,18 @@ void populate_cast_transpose_dbias_workspace_config(const Tensor &cast_output, /
workspace->data.dtype = DType::kFloat32;
} else {
// Check that workspace matches expected size
const size_t workspace_size =
const size_t workspace_size = get_buffer_size_bytes(
std::accumulate(workspace->data.shape.begin(), workspace->data.shape.end(), 1,
std::multiplies<size_t>()) *
typeToSize(workspace->data.dtype);
const size_t required_size = num_rows_partial_dbias * row_length * typeToSize(DType::kFloat32);
std::multiplies<size_t>()),
workspace->data.dtype);
const size_t required_size =
get_buffer_size_bytes(num_rows_partial_dbias, row_length, DType::kFloat32);
NVTE_CHECK(!workspace->data.shape.empty(), "Invalid workspace dims (expected (",
num_rows_partial_dbias, ",", row_length, "), found ())");
NVTE_CHECK(workspace_size >= required_size, "Invalid workspace (expected dims=(",
num_rows_partial_dbias, ",", row_length, "), dtype=", to_string(DType::kFloat32),
"; found dims=", workspace->data.shape,
", dtype=", typeToSize(workspace->data.dtype), ")");
", dtype=", typeToNumBits(workspace->data.dtype), " bits)");
}
}
......
......@@ -237,8 +237,8 @@ void multi_cast_transpose(const std::vector<Tensor*> input_list, std::vector<Ten
// Input matrices are divided into tiles
// Note: Each tile is a warp_size x warp_size grid of nvec_out x nvec_in subtiles
const int tile_dim_m = THREADS_PER_WARP * desired_store_size / typeToSize(otype);
const int tile_dim_n = THREADS_PER_WARP * desired_load_size / typeToSize(itype);
const int tile_dim_m = THREADS_PER_WARP * desired_store_size * 8 / typeToNumBits(otype);
const int tile_dim_n = THREADS_PER_WARP * desired_load_size * 8 / typeToNumBits(itype);
// Add tensors to kernel argument struct
MultiCastTransposeArgs kernel_args_aligned, kernel_args_unaligned;
......
......@@ -460,7 +460,7 @@ CUtensorMap get_tensor_map(const SimpleTensor& tensor, size_t global_dim_x, size
CUtensorMap tensor_map_output_trans{};
create_2D_tensor_map(tensor_map_output_trans, tensor, global_dim_y, global_dim_x,
/*shmemY=*/BLOCK_TILE_DIM, /*shmemX=*/BLOCK_TILE_DIM,
/*stride_elems=*/global_dim_x, /*offset_elems=*/0, sizeof(OutputType));
/*stride_elems=*/global_dim_x, /*offset_elems=*/0, sizeof(OutputType) * 8);
return tensor_map_output_trans;
}
......
......@@ -382,17 +382,18 @@ void populate_transpose_dbias_workspace_config(const Tensor &input, /*cast*/
workspace->data.dtype = DType::kFloat32;
} else {
// Check that workspace matches expected size
const size_t workspace_size =
const size_t workspace_size = get_buffer_size_bytes(
std::accumulate(workspace->data.shape.begin(), workspace->data.shape.end(), 1,
std::multiplies<size_t>()) *
typeToSize(workspace->data.dtype);
const size_t required_size = num_rows_partial_dbias * row_length * typeToSize(DType::kFloat32);
std::multiplies<size_t>()),
workspace->data.dtype);
const size_t required_size =
get_buffer_size_bytes(num_rows_partial_dbias, row_length, DType::kFloat32);
NVTE_CHECK(!workspace->data.shape.empty(), "Invalid workspace dims (expected (",
num_rows_partial_dbias, ",", row_length, "), found ())");
NVTE_CHECK(workspace_size >= required_size, "Invalid workspace (expected dims=(",
num_rows_partial_dbias, ",", row_length, "), dtype=", to_string(DType::kFloat32),
"; found dims=", workspace->data.shape,
", dtype=", typeToSize(workspace->data.dtype), ")");
", dtype=", typeToNumBits(workspace->data.dtype), " bits)");
}
}
......
......@@ -754,19 +754,20 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu
if constexpr (IS_DGATED) {
create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X,
cols, 0, sizeof(IType));
cols, 0, typeToNumBits(gated_input.dtype()));
}
const uint32_t tensor_stride_elems = output_cols;
create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, SHMEM_DIM_Y,
SHMEM_DIM_X, cols * 2, 0, sizeof(IType));
SHMEM_DIM_X, cols * 2, 0, typeToNumBits(gated_input.dtype()));
create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, SHMEM_DIM_Y,
SHMEM_DIM_X, cols * 2, cols, sizeof(IType));
SHMEM_DIM_X, cols * 2, cols, typeToNumBits(gated_input.dtype()));
create_2D_tensor_map(tensor_map_output_act, output->data, rows, cols, SHMEM_DIM_Y,
SHMEM_DIM_X, tensor_stride_elems, 0, sizeof(OType));
SHMEM_DIM_X, tensor_stride_elems, 0, typeToNumBits(output->dtype()));
create_2D_tensor_map(tensor_map_output_gate, output->data, rows, cols, SHMEM_DIM_Y,
SHMEM_DIM_X, tensor_stride_elems, cols, sizeof(OType));
SHMEM_DIM_X, tensor_stride_elems, cols,
typeToNumBits(output->dtype()));
const size_t buff_elems_total = BUFFERS_NUM * SHMEM_DIM_Y * SHMEM_DIM_X;
const size_t buff_size_aligned_in =
......@@ -849,31 +850,33 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
if constexpr (IS_DGATED) {
create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, SHMEM_DIM_Y,
SHMEM_DIM_X, cols, 0, sizeof(IType));
SHMEM_DIM_X, cols, 0, typeToNumBits(gated_input.dtype()));
}
const uint32_t tensor_stride_elems = output_cols;
create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols,
SHMEM_DIM_Y, SHMEM_DIM_X, cols * 2, 0, sizeof(IType));
SHMEM_DIM_Y, SHMEM_DIM_X, cols * 2, 0,
typeToNumBits(gated_input.dtype()));
create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols,
SHMEM_DIM_Y, SHMEM_DIM_X, cols * 2, cols, sizeof(IType));
SHMEM_DIM_Y, SHMEM_DIM_X, cols * 2, cols,
typeToNumBits(gated_input.dtype()));
if (USE_ROWWISE_SCALING) {
create_2D_tensor_map(tensor_map_output_act_rowwise, output->data, rows, cols,
SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems, 0,
sizeof(OType));
typeToNumBits(output->dtype()));
create_2D_tensor_map(tensor_map_output_gate_rowwise, output->data, rows, cols,
SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems, cols,
sizeof(OType));
typeToNumBits(output->dtype()));
}
if (USE_COLWISE_SCALING) {
create_2D_tensor_map(tensor_map_output_act_colwise, output->columnwise_data,
rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems,
0, sizeof(OType));
0, typeToNumBits(output->dtype()));
create_2D_tensor_map(tensor_map_output_gate_colwise, output->columnwise_data,
rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems,
cols, sizeof(OType));
cols, typeToNumBits(output->dtype()));
}
const size_t buff_elems_total = BUFFERS_NUM * SHMEM_DIM_Y * SHMEM_DIM_X;
......
......@@ -895,15 +895,15 @@ void cast_fp8_2D(const Tensor &input, const Tensor *act_input, Tensor *output, T
alignas(64) CUtensorMap tensor_map_output{};
create_2D_tensor_map(tensor_map_input, input.data, rows, cols, FP8_SHMEM_DIM_Y,
FP8_SHMEM_DIM_X, cols, 0, sizeof(IType));
FP8_SHMEM_DIM_X, cols, 0, typeToNumBits(input.data.dtype));
if constexpr (IS_DACT) {
create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols, FP8_SHMEM_DIM_Y,
FP8_SHMEM_DIM_X, cols, 0, sizeof(IType));
FP8_SHMEM_DIM_X, cols, 0, typeToNumBits(input.data.dtype));
}
create_2D_tensor_map(tensor_map_output, output->data, rows, cols, FP8_SHMEM_DIM_Y,
FP8_SHMEM_DIM_X, cols, 0, sizeof(OType));
FP8_SHMEM_DIM_X, cols, 0, typeToNumBits(output->data.dtype));
cast_fp8_2D_kernel<IS_DBIAS, IS_DACT, ParamOP, OP, IType, OType>
<<<grid, block, 0, stream>>>(tensor_map_input, tensor_map_act_input, tensor_map_output,
......@@ -991,24 +991,24 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
alignas(64) CUtensorMap tensor_map_output_colwise{};
create_2D_tensor_map(tensor_map_input, input.data, rows, cols, MXFP8_SHMEM_DIM_Y,
MXFP8_SHMEM_DIM_X, cols, 0, sizeof(IType));
MXFP8_SHMEM_DIM_X, cols, 0, typeToNumBits(input.dtype()));
if constexpr (IS_DACT) {
create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols,
MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, cols, 0,
sizeof(IType));
typeToNumBits(input.dtype()));
}
if (use_rowwise_scaling) {
create_2D_tensor_map(tensor_map_output_rowwise, output->data, rows, cols,
MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, cols, 0,
sizeof(OType));
typeToNumBits(output->dtype()));
}
if (use_colwise_scaling) {
create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, rows,
cols, MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, cols, 0,
sizeof(OType));
typeToNumBits(output->dtype()));
}
cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType,
......@@ -1101,7 +1101,7 @@ static bool is_full_tile_1D_tensor(const Tensor *const t) {
bool dimensions_supported_by_TMA(const Tensor *const t) {
const size_t cols = t->flat_last_dim();
constexpr int TMA_bytes = 16;
const int alignment_requirement = TMA_bytes / typeToSize(t->dtype());
const int alignment_requirement = (TMA_bytes * 8) / typeToNumBits(t->dtype());
return cols % alignment_requirement == 0;
}
......
......@@ -319,9 +319,9 @@ static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t s
alignas(64) CUtensorMap tensor_map_output{};
create_2D_tensor_map(tensor_map_input, input_data, rows, cols, SHMEM_DIM_Y,
SHMEM_DIM_X, cols, 0, sizeof(IType));
SHMEM_DIM_X, cols, 0, typeToNumBits(input.dtype()));
create_2D_tensor_map(tensor_map_output, output->data, rows, cols, SHMEM_DIM_Y,
SHMEM_DIM_X, cols, 0, sizeof(OType));
SHMEM_DIM_X, cols, 0, typeToNumBits(output->dtype()));
dequantize_mxfp8_kernel<IType, OType, SCALE_DIM_Y, SCALE_DIM_X>
<<<grid, block, 0, stream>>>(tensor_map_input, tensor_map_output, scales_ptr,
......
......@@ -155,8 +155,8 @@ void multi_padding(const std::vector<Tensor*> input_list, std::vector<Tensor*> o
// Input matrices are divided into tiles
// Note: Each tile is a warp_size x warp_size grid of nvec x nvec subtiles
const int tile_dim_m = THREADS_PER_WARP * desired_load_store_size / typeToSize(type);
const int tile_dim_n = THREADS_PER_WARP * desired_load_store_size / typeToSize(type);
const int tile_dim_m = THREADS_PER_WARP * desired_load_store_size * 8 / typeToNumBits(type);
const int tile_dim_n = THREADS_PER_WARP * desired_load_store_size * 8 / typeToNumBits(type);
// Add tensors to kernel argument struct
MultiPaddingArgs kernel_args;
......
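
The padding and transpose kernels now derive tile dimensions from bit widths as well. A small numeric illustration; the 8-byte access size is an assumed value, not taken from the kernels above:

```cpp
#include <cassert>
#include <cstddef>

int main() {
  constexpr std::size_t THREADS_PER_WARP = 32;
  constexpr std::size_t desired_load_store_size = 8;  // bytes per access (assumed)

  // Tile dimension = warp size * access size, expressed in elements of the
  // given width; computing in bits keeps sub-byte types exact.
  const std::size_t tile_bf16 = THREADS_PER_WARP * desired_load_store_size * 8 / 16;  // 128
  const std::size_t tile_fp8  = THREADS_PER_WARP * desired_load_store_size * 8 / 8;   // 256
  const std::size_t tile_fp4  = THREADS_PER_WARP * desired_load_store_size * 8 / 4;   // 512
  assert(tile_bf16 == 128 && tile_fp8 == 256 && tile_fp4 == 512);
  return 0;
}
```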