Unverified commit 7b94bd99, authored by Oleg Goncharov, committed by GitHub
Browse files

[common] Added support of FP4 data type (#1779)



* Added support of FP4 data type
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* Refactoring to BitsNum in progress
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* Fixed compilation errors. All C++ tests passed
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* Fixed a typo
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Added FP4 guard to TMA tensor descriptor data type
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fixed errors in JAX C++ extensions
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Removed dummy NVFP4 C++ test file
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* Make pytorch changes
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Refactored the code per the review notes. Fixed JAX build error.
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Removed unnecessary static casts
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* Typo fix
Signed-off-by: default avatarOleg Goncharov <64355998+Oleg-Goncharov@users.noreply.github.com>

* Pass correct num bits to create_2D_tensor_map; fixes CI
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* inline funcs
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Oleg Goncharov <64355998+Oleg-Goncharov@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent e963e4a9
...@@ -11,6 +11,10 @@ ...@@ -11,6 +11,10 @@
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <cuda_fp8.h> #include <cuda_fp8.h>
#if CUDA_VERSION >= 12080
#include <cuda_fp4.h>
#endif
#if !defined(__CUDACC_RTC__) #if !defined(__CUDACC_RTC__)
#include <cstdint> #include <cstdint>
#else #else
......
...@@ -221,21 +221,23 @@ std::vector<size_t> getTensorShape(at::Tensor t); ...@@ -221,21 +221,23 @@ std::vector<size_t> getTensorShape(at::Tensor t);
transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid, transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid,
const std::string& fp8_recipe); const std::string& fp8_recipe);
// Returns the width in BITS of a single element of the given TE data type.
// Bit (not byte) granularity is required because kFloat4E2M1 (NVFP4) occupies
// only 4 bits; callers converting to bytes must multiply element counts first
// and divide by 8 (see e.g. mha_fill's `fcd_size * element_size_bits / 8`).
// Raises via NVTE_ERROR for any unhandled/invalid DType.
inline size_t typeToNumBits(transformer_engine::DType t) {
  switch (t) {
    case transformer_engine::DType::kInt64:
      return 64;
    case transformer_engine::DType::kInt32:
    case transformer_engine::DType::kFloat32:
      return 32;
    case transformer_engine::DType::kInt16:
    case transformer_engine::DType::kFloat16:
    case transformer_engine::DType::kBFloat16:
      return 16;
    case transformer_engine::DType::kByte:
    case transformer_engine::DType::kFloat8E4M3:
    case transformer_engine::DType::kFloat8E5M2:
      return 8;
    case transformer_engine::DType::kFloat4E2M1:
      // FP4 (E2M1): sub-byte type, 4 bits per element.
      return 4;
    default:
      NVTE_ERROR("Invalid type");
  }
}
......
...@@ -24,12 +24,12 @@ void mha_fill(const transformer_engine::TensorWrapper &self, const at::Tensor &s ...@@ -24,12 +24,12 @@ void mha_fill(const transformer_engine::TensorWrapper &self, const at::Tensor &s
NVTE_CHECK(fcd_size % block_size == 0, "input size not aligned to block size"); NVTE_CHECK(fcd_size % block_size == 0, "input size not aligned to block size");
size_t element_size = transformer_engine::pytorch::typeToSize(self.dtype()); size_t element_size_bits = transformer_engine::pytorch::typeToNumBits(self.dtype());
int32_t start_row = start_index.data_ptr<int32_t>()[0]; int32_t start_row = start_index.data_ptr<int32_t>()[0];
void *base_ptr = static_cast<char *>(self.get_rowwise_data().data_ptr) + void *base_ptr = static_cast<char *>(self.get_rowwise_data().data_ptr) +
static_cast<size_t>(start_row) * fcd_size * element_size; static_cast<size_t>(start_row) * fcd_size * element_size_bits / 8;
size_t num_rows_to_zero = max_tokens - start_row; size_t num_rows_to_zero = max_tokens - start_row;
size_t total_bytes = num_rows_to_zero * fcd_size * element_size; size_t total_bytes = num_rows_to_zero * fcd_size * element_size_bits / 8;
NVTE_SCOPED_GIL_RELEASE( NVTE_SCOPED_GIL_RELEASE(
{ nvte_memset(base_ptr, 0, total_bytes, at::cuda::getCurrentCUDAStream()); }); { nvte_memset(base_ptr, 0, total_bytes, at::cuda::getCurrentCUDAStream()); });
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment