Unverified commit c0df246a authored by Kirthi Shankar Sivamani, committed by GitHub
Browse files

[C][PyTorch] Move FP8 block scaling kernels to core (#1730)



* Move FP8 block scaling kernels to core
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix symbol error
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix arg
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent afb70224
...@@ -90,6 +90,7 @@ list(APPEND transformer_engine_SOURCES ...@@ -90,6 +90,7 @@ list(APPEND transformer_engine_SOURCES
fused_rope/fused_rope.cu fused_rope/fused_rope.cu
recipe/current_scaling.cu recipe/current_scaling.cu
recipe/delayed_scaling.cu recipe/delayed_scaling.cu
recipe/fp8_block_scaling.cu
comm_gemm_overlap/userbuffers/ipcsocket.cc comm_gemm_overlap/userbuffers/ipcsocket.cc
comm_gemm_overlap/userbuffers/userbuffers-host.cpp comm_gemm_overlap/userbuffers/userbuffers-host.cpp
comm_gemm_overlap/userbuffers/userbuffers.cu comm_gemm_overlap/userbuffers/userbuffers.cu
......
...@@ -96,6 +96,17 @@ void nvte_compute_amax(const NVTETensor input, NVTETensor output, cudaStream_t s ...@@ -96,6 +96,17 @@ void nvte_compute_amax(const NVTETensor input, NVTETensor output, cudaStream_t s
void nvte_compute_scale_from_amax(NVTETensor output, const NVTEQuantizationConfig config, void nvte_compute_scale_from_amax(NVTETensor output, const NVTEQuantizationConfig config,
cudaStream_t stream); cudaStream_t stream);
/*! \brief Compute per-block amax values for a contiguous sub-range of a 2D tensor.
 *
 *  NOTE(review): declaration only — the visible implementation requires
 *  block_len == 128; amax is written as a 2D grid of per-tile maxima indexed
 *  via (amax_stride_h, amax_stride_w). start_offset selects where in the
 *  flattened h*w input the sub-range begins — confirm exact tiling semantics
 *  against the recipe implementation.
 */
void nvte_fp8_block_scaling_compute_partial_amax(const NVTETensor inp, NVTETensor amax, size_t h,
                                                 size_t w, size_t amax_stride_h,
                                                 size_t amax_stride_w, size_t start_offset,
                                                 size_t block_len, cudaStream_t stream);
/*! \brief Cast a contiguous sub-range of a 2D tensor to FP8 using per-block scales.
 *
 *  NOTE(review): declaration only — the visible implementation requires
 *  block_len == 128 and out_dtype of kFloat8E4M3 or kFloat8E5M2; scale is read
 *  as a 2D grid via (scale_stride_h, scale_stride_w). Confirm against the
 *  recipe implementation.
 */
void nvte_fp8_block_scaling_partial_cast(const NVTETensor inp, NVTETensor out,
                                         const NVTETensor scale, size_t h, size_t w,
                                         size_t scale_stride_h, size_t scale_stride_w,
                                         size_t start_offset, size_t block_len,
                                         const NVTEDType out_dtype, cudaStream_t stream);
#ifdef __cplusplus #ifdef __cplusplus
} // extern "C" } // extern "C"
#endif #endif
......
...@@ -4,10 +4,15 @@ ...@@ -4,10 +4,15 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include "common/common.h" #include <transformer_engine/recipe.h>
#include "common/utils.cuh"
#include "extensions.h" #include <cassert>
#include "type_shim.h"
#include "../common.h"
#include "../utils.cuh"
namespace transformer_engine {
namespace fp8_block_scaling_recipe {
constexpr int kTileDim = 128; constexpr int kTileDim = 128;
constexpr int kThreadsPerBlock = 256; constexpr int kThreadsPerBlock = 256;
...@@ -153,18 +158,13 @@ __global__ void __launch_bounds__(kThreadsPerBlock) ...@@ -153,18 +158,13 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
} }
} }
void fp8_block_scaling_compute_partial_amax(const at::Tensor &tensor, at::Tensor amax, size_t h, void fp8_block_scaling_compute_partial_amax(const Tensor inp, Tensor amax, size_t h, size_t w,
size_t w, size_t start_offset, size_t block_len) { size_t amax_stride_h, size_t amax_stride_w,
TORCH_CHECK(block_len == 128, "Currently only support block_len = 128"); size_t start_offset, size_t block_len,
TORCH_CHECK(amax.dim() == 2, "amax must be a 2D tensor"); cudaStream_t stream) {
TORCH_CHECK(amax.scalar_type() == at::ScalarType::Float, "amax must be a float tensor"); NVTE_CHECK(block_len == 128, "Currently only block_len = 128 is supported");
TORCH_CHECK(tensor.scalar_type() == at::ScalarType::Float ||
tensor.scalar_type() == at::ScalarType::BFloat16,
"tensor must be a float or bfloat16 tensor");
size_t amax_stride_h = amax.stride(0); size_t len = inp.numel();
size_t amax_stride_w = amax.stride(1);
size_t len = tensor.numel();
assert(h > 0 && w > 0); assert(h > 0 && w > 0);
assert(start_offset < h * w); assert(start_offset < h * w);
...@@ -176,31 +176,21 @@ void fp8_block_scaling_compute_partial_amax(const at::Tensor &tensor, at::Tensor ...@@ -176,31 +176,21 @@ void fp8_block_scaling_compute_partial_amax(const at::Tensor &tensor, at::Tensor
assert(blocks_y <= std::numeric_limits<unsigned int>::max()); assert(blocks_y <= std::numeric_limits<unsigned int>::max());
dim3 grid(blocks_x, blocks_y); dim3 grid(blocks_x, blocks_y);
auto stream = at::cuda::getCurrentCUDAStream(); TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
inp.dtype(), inp_dtype,
DISPATCH_FLOAT_HALF_AND_BFLOAT(tensor.scalar_type(), 0, "compute_partial_amax", fp8_block_scaling_compute_partial_amax_kernel<inp_dtype>
fp8_block_scaling_compute_partial_amax_kernel<scalar_t_0> <<<grid, kThreadsPerBlock, 0, stream>>>(reinterpret_cast<const inp_dtype *>(inp.data.dptr),
<<<grid, kThreadsPerBlock, 0, stream>>>( reinterpret_cast<float *>(amax.data.dptr),
tensor.data_ptr<scalar_t_0>(), amax.data_ptr<float>(), amax_stride_h, amax_stride_w, h, w, start_offset,
amax_stride_h, amax_stride_w, h, w, start_offset, len);) len);)
} }
void fp8_block_scaling_partial_cast(const at::Tensor &inp, at::Tensor out, const at::Tensor &scale, void fp8_block_scaling_partial_cast(const Tensor inp, Tensor out, const Tensor scale, size_t h,
size_t h, size_t w, size_t start_offset, size_t block_len, size_t w, size_t scale_stride_h, size_t scale_stride_w,
const transformer_engine::DType out_dtype) { size_t start_offset, size_t block_len, const DType out_dtype,
TORCH_CHECK(block_len == 128, "Currently only support block_len = 128"); cudaStream_t stream) {
TORCH_CHECK(scale.dim() == 2, "scale must be a 2D tensor"); NVTE_CHECK(block_len == 128, "Currently only block_len = 128 is supported");
TORCH_CHECK(scale.scalar_type() == at::ScalarType::Float, "scale must be a float tensor");
TORCH_CHECK(
inp.scalar_type() == at::ScalarType::Float || inp.scalar_type() == at::ScalarType::BFloat16,
"input must be a float or bfloat16 tensor");
TORCH_CHECK(out.scalar_type() == at::ScalarType::Byte, "output must be a uint8 tensor");
TORCH_CHECK(out_dtype == transformer_engine::DType::kFloat8E4M3 ||
out_dtype == transformer_engine::DType::kFloat8E5M2,
"out_dtype must be kFloat8E4M3 or kFloat8E5M2");
size_t scale_stride_h = scale.stride(0);
size_t scale_stride_w = scale.stride(1);
size_t len = inp.numel(); size_t len = inp.numel();
assert(h > 0 && w > 0); assert(h > 0 && w > 0);
...@@ -213,17 +203,43 @@ void fp8_block_scaling_partial_cast(const at::Tensor &inp, at::Tensor out, const ...@@ -213,17 +203,43 @@ void fp8_block_scaling_partial_cast(const at::Tensor &inp, at::Tensor out, const
assert(blocks_y <= std::numeric_limits<unsigned int>::max()); assert(blocks_y <= std::numeric_limits<unsigned int>::max());
dim3 grid(blocks_x, blocks_y); dim3 grid(blocks_x, blocks_y);
auto stream = at::cuda::getCurrentCUDAStream(); TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
inp.dtype(), inp_dtype,
DISPATCH_FLOAT_HALF_AND_BFLOAT(
inp.scalar_type(), 0, "partial_cast",
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
out_dtype, fp8_type, out_dtype, fp8_type,
TRANSFORMER_ENGINE_SWITCH_CONDITION( TRANSFORMER_ENGINE_SWITCH_CONDITION(
w % kTileDim == 0, kWidthAligned, w % kTileDim == 0, kWidthAligned,
fp8_block_scaling_partial_cast_kernel<scalar_t_0, fp8_type, kWidthAligned> fp8_block_scaling_partial_cast_kernel<inp_dtype, fp8_type, kWidthAligned>
<<<grid, kThreadsPerBlock, 0, stream>>>(inp.data_ptr<scalar_t_0>(), <<<grid, kThreadsPerBlock, 0, stream>>>(
reinterpret_cast<fp8_type *>(out.data_ptr()), reinterpret_cast<const inp_dtype *>(inp.data.dptr),
scale.data_ptr<float>(), scale_stride_h, reinterpret_cast<fp8_type *>(out.data.dptr),
scale_stride_w, h, w, start_offset, len);))) reinterpret_cast<const float *>(scale.data.dptr), scale_stride_h, scale_stride_w,
h, w, start_offset, len);)))
}
} // namespace fp8_block_scaling_recipe
} // namespace transformer_engine
/*! C API entry point for the partial-amax computation.
 *
 *  Unwraps the opaque NVTETensor handles into core Tensor objects and forwards
 *  every argument unchanged to the core implementation in
 *  transformer_engine::fp8_block_scaling_recipe.
 */
void nvte_fp8_block_scaling_compute_partial_amax(const NVTETensor inp, NVTETensor amax, size_t h,
                                                 size_t w, size_t amax_stride_h,
                                                 size_t amax_stride_w, size_t start_offset,
                                                 size_t block_len, cudaStream_t stream) {
  NVTE_API_CALL(nvte_fp8_block_scaling_compute_partial_amax);
  // NVTETensor handles are opaque pointers to core Tensor objects.
  const auto *input_tensor = reinterpret_cast<const transformer_engine::Tensor *>(inp);
  auto *amax_tensor = reinterpret_cast<transformer_engine::Tensor *>(amax);
  transformer_engine::fp8_block_scaling_recipe::fp8_block_scaling_compute_partial_amax(
      *input_tensor, *amax_tensor, h, w, amax_stride_h, amax_stride_w, start_offset, block_len,
      stream);
}
/*! C API entry point for the per-block partial cast to FP8.
 *
 *  Unwraps the opaque NVTETensor handles, converts the C enum out_dtype to the
 *  core DType, and forwards all arguments unchanged to the core implementation
 *  in transformer_engine::fp8_block_scaling_recipe.
 */
void nvte_fp8_block_scaling_partial_cast(const NVTETensor inp, NVTETensor out,
                                         const NVTETensor scale, size_t h, size_t w,
                                         size_t scale_stride_h, size_t scale_stride_w,
                                         size_t start_offset, size_t block_len,
                                         const NVTEDType out_dtype, cudaStream_t stream) {
  NVTE_API_CALL(nvte_fp8_block_scaling_partial_cast);
  // NVTETensor handles are opaque pointers to core Tensor objects.
  const auto *input_tensor = reinterpret_cast<const transformer_engine::Tensor *>(inp);
  auto *output_tensor = reinterpret_cast<transformer_engine::Tensor *>(out);
  const auto *scale_tensor = reinterpret_cast<const transformer_engine::Tensor *>(scale);
  transformer_engine::fp8_block_scaling_recipe::fp8_block_scaling_partial_cast(
      *input_tensor, *output_tensor, *scale_tensor, h, w, scale_stride_h, scale_stride_w,
      start_offset, block_len, static_cast<transformer_engine::DType>(out_dtype), stream);
}
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
/*! PyTorch binding: validate tensor arguments, then launch the core NVTE
 *  partial-amax kernel on the current CUDA stream.
 */
void fp8_block_scaling_compute_partial_amax(const at::Tensor &tensor, at::Tensor amax, size_t h,
                                            size_t w, size_t start_offset, size_t block_len) {
  using namespace transformer_engine;
  using namespace transformer_engine::pytorch;
  // Validate arguments up front; these mirror the core kernel's preconditions.
  TORCH_CHECK(block_len == 128, "Currently only block_len = 128 is supported");
  TORCH_CHECK(amax.dim() == 2, "amax must be a 2D tensor");
  TORCH_CHECK(amax.scalar_type() == at::ScalarType::Float, "amax must be a float tensor");
  TORCH_CHECK(tensor.scalar_type() == at::ScalarType::Float ||
                  tensor.scalar_type() == at::ScalarType::BFloat16,
              "tensor must be a float or bfloat16 tensor");
  // Wrap the torch tensors as NVTE tensors and dispatch. amax strides are
  // passed explicitly so the kernel honors non-contiguous amax layouts.
  const TensorWrapper input_te = makeTransformerEngineTensor(tensor);
  TensorWrapper amax_te = makeTransformerEngineTensor(amax);
  const size_t stride_h = amax.stride(0);
  const size_t stride_w = amax.stride(1);
  nvte_fp8_block_scaling_compute_partial_amax(input_te.data(), amax_te.data(), h, w, stride_h,
                                              stride_w, start_offset, block_len,
                                              at::cuda::getCurrentCUDAStream());
}
/*! PyTorch binding: validate tensor arguments and dtype, then launch the core
 *  NVTE per-block partial cast to FP8 on the current CUDA stream.
 */
void fp8_block_scaling_partial_cast(const at::Tensor &inp, at::Tensor out, const at::Tensor &scale,
                                    size_t h, size_t w, size_t start_offset, size_t block_len,
                                    const transformer_engine::DType out_dtype) {
  using namespace transformer_engine;
  using namespace transformer_engine::pytorch;
  // Validate arguments up front; these mirror the core kernel's preconditions.
  TORCH_CHECK(block_len == 128, "Currently only block_len = 128 is supported");
  TORCH_CHECK(scale.dim() == 2, "scale must be a 2D tensor");
  TORCH_CHECK(scale.scalar_type() == at::ScalarType::Float, "scale must be a float tensor");
  TORCH_CHECK(
      inp.scalar_type() == at::ScalarType::Float || inp.scalar_type() == at::ScalarType::BFloat16,
      "input must be a float or bfloat16 tensor");
  // FP8 output is stored in a uint8 buffer; the real element type is out_dtype.
  TORCH_CHECK(out.scalar_type() == at::ScalarType::Byte, "output must be a uint8 tensor");
  TORCH_CHECK(out_dtype == transformer_engine::DType::kFloat8E4M3 ||
                  out_dtype == transformer_engine::DType::kFloat8E5M2,
              "out_dtype must be kFloat8E4M3 or kFloat8E5M2");
  // Wrap the torch tensors as NVTE tensors and dispatch. scale strides are
  // passed explicitly so the kernel honors non-contiguous scale layouts.
  const TensorWrapper input_te = makeTransformerEngineTensor(inp);
  TensorWrapper output_te = makeTransformerEngineTensor(out);
  const TensorWrapper scale_te = makeTransformerEngineTensor(scale);
  const size_t stride_h = scale.stride(0);
  const size_t stride_w = scale.stride(1);
  nvte_fp8_block_scaling_partial_cast(input_te.data(), output_te.data(), scale_te.data(), h, w,
                                      stride_h, stride_w, start_offset, block_len,
                                      static_cast<NVTEDType>(out_dtype),
                                      at::cuda::getCurrentCUDAStream());
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment