"vscode:/vscode.git/clone" did not exist on "c267b1a02c952b68a897c96201f32ad57e0b955e"
Unverified Commit c0df246a authored by Kirthi Shankar Sivamani, committed by GitHub
Browse files

[C][PyTorch] Move FP8 block scaling kernels to core (#1730)



* Move FP8 block scaling kernels to core
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix symbol error
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix arg
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent afb70224
......@@ -90,6 +90,7 @@ list(APPEND transformer_engine_SOURCES
fused_rope/fused_rope.cu
recipe/current_scaling.cu
recipe/delayed_scaling.cu
recipe/fp8_block_scaling.cu
comm_gemm_overlap/userbuffers/ipcsocket.cc
comm_gemm_overlap/userbuffers/userbuffers-host.cpp
comm_gemm_overlap/userbuffers/userbuffers.cu
......
......@@ -96,6 +96,17 @@ void nvte_compute_amax(const NVTETensor input, NVTETensor output, cudaStream_t s
void nvte_compute_scale_from_amax(NVTETensor output, const NVTEQuantizationConfig config,
cudaStream_t stream);
/*! \brief Launch the FP8 block-scaling partial-amax kernel on the given stream.
 *
 *  \param[in]     inp            Input tensor. The implementation dispatches on its dtype
 *                                through a non-FP8 type switch.
 *  \param[in,out] amax           Tensor whose data pointer is handed to the kernel as float*.
 *  \param[in]     h              Height passed to the kernel; asserted to be > 0.
 *                                NOTE(review): presumably the row count of the full 2D tensor
 *                                that `inp` is a slice of -- confirm against the kernel.
 *  \param[in]     w              Width passed to the kernel; asserted to be > 0.
 *  \param[in]     amax_stride_h  Stride of `amax` along its first dimension (the PyTorch
 *                                caller passes amax.stride(0)).
 *  \param[in]     amax_stride_w  Stride of `amax` along its second dimension (the PyTorch
 *                                caller passes amax.stride(1)).
 *  \param[in]     start_offset   Offset passed to the kernel; asserted to be < h * w.
 *  \param[in]     block_len      Block length; only 128 is accepted.
 *  \param[in]     stream         CUDA stream the kernel is launched on.
 */
void nvte_fp8_block_scaling_compute_partial_amax(const NVTETensor inp, NVTETensor amax, size_t h,
size_t w, size_t amax_stride_h,
size_t amax_stride_w, size_t start_offset,
size_t block_len, cudaStream_t stream);
/*! \brief Launch the FP8 block-scaling partial-cast kernel on the given stream.
 *
 *  \param[in]     inp             Input tensor. The implementation dispatches on its dtype
 *                                 through a non-FP8 type switch.
 *  \param[in,out] out             Output tensor; its data pointer is reinterpreted as the FP8
 *                                 type selected by `out_dtype`.
 *  \param[in]     scale           Tensor whose data pointer is handed to the kernel as
 *                                 const float*.
 *  \param[in]     h               Height passed to the kernel; asserted to be > 0.
 *  \param[in]     w               Width passed to the kernel; asserted to be > 0. Whether it is
 *                                 a multiple of the 128-element tile selects the kernel variant.
 *  \param[in]     scale_stride_h  Stride of `scale` along its first dimension (the PyTorch
 *                                 caller passes scale.stride(0)).
 *  \param[in]     scale_stride_w  Stride of `scale` along its second dimension (the PyTorch
 *                                 caller passes scale.stride(1)).
 *  \param[in]     start_offset    Offset passed to the kernel.
 *                                 NOTE(review): presumably the position of `inp` within the
 *                                 flattened h x w tensor -- confirm against the kernel.
 *  \param[in]     block_len       Block length; only 128 is accepted.
 *  \param[in]     out_dtype       FP8 output type used for the FP8-only type switch (the
 *                                 PyTorch caller restricts it to kFloat8E4M3 / kFloat8E5M2).
 *  \param[in]     stream          CUDA stream the kernel is launched on.
 */
void nvte_fp8_block_scaling_partial_cast(const NVTETensor inp, NVTETensor out,
const NVTETensor scale, size_t h, size_t w,
size_t scale_stride_h, size_t scale_stride_w,
size_t start_offset, size_t block_len,
const NVTEDType out_dtype, cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
#endif
......
......@@ -4,10 +4,15 @@
* See LICENSE for license information.
************************************************************************/
#include "common/common.h"
#include "common/utils.cuh"
#include "extensions.h"
#include "type_shim.h"
#include <transformer_engine/recipe.h>
#include <cassert>
#include "../common.h"
#include "../utils.cuh"
namespace transformer_engine {
namespace fp8_block_scaling_recipe {
constexpr int kTileDim = 128;
constexpr int kThreadsPerBlock = 256;
......@@ -153,18 +158,13 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
}
}
void fp8_block_scaling_compute_partial_amax(const at::Tensor &tensor, at::Tensor amax, size_t h,
size_t w, size_t start_offset, size_t block_len) {
TORCH_CHECK(block_len == 128, "Currently only support block_len = 128");
TORCH_CHECK(amax.dim() == 2, "amax must be a 2D tensor");
TORCH_CHECK(amax.scalar_type() == at::ScalarType::Float, "amax must be a float tensor");
TORCH_CHECK(tensor.scalar_type() == at::ScalarType::Float ||
tensor.scalar_type() == at::ScalarType::BFloat16,
"tensor must be a float or bfloat16 tensor");
void fp8_block_scaling_compute_partial_amax(const Tensor inp, Tensor amax, size_t h, size_t w,
size_t amax_stride_h, size_t amax_stride_w,
size_t start_offset, size_t block_len,
cudaStream_t stream) {
NVTE_CHECK(block_len == 128, "Currently only block_len = 128 is supported");
size_t amax_stride_h = amax.stride(0);
size_t amax_stride_w = amax.stride(1);
size_t len = tensor.numel();
size_t len = inp.numel();
assert(h > 0 && w > 0);
assert(start_offset < h * w);
......@@ -176,31 +176,21 @@ void fp8_block_scaling_compute_partial_amax(const at::Tensor &tensor, at::Tensor
assert(blocks_y <= std::numeric_limits<unsigned int>::max());
dim3 grid(blocks_x, blocks_y);
auto stream = at::cuda::getCurrentCUDAStream();
DISPATCH_FLOAT_HALF_AND_BFLOAT(tensor.scalar_type(), 0, "compute_partial_amax",
fp8_block_scaling_compute_partial_amax_kernel<scalar_t_0>
<<<grid, kThreadsPerBlock, 0, stream>>>(
tensor.data_ptr<scalar_t_0>(), amax.data_ptr<float>(),
amax_stride_h, amax_stride_w, h, w, start_offset, len);)
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
inp.dtype(), inp_dtype,
fp8_block_scaling_compute_partial_amax_kernel<inp_dtype>
<<<grid, kThreadsPerBlock, 0, stream>>>(reinterpret_cast<const inp_dtype *>(inp.data.dptr),
reinterpret_cast<float *>(amax.data.dptr),
amax_stride_h, amax_stride_w, h, w, start_offset,
len);)
}
void fp8_block_scaling_partial_cast(const at::Tensor &inp, at::Tensor out, const at::Tensor &scale,
size_t h, size_t w, size_t start_offset, size_t block_len,
const transformer_engine::DType out_dtype) {
TORCH_CHECK(block_len == 128, "Currently only support block_len = 128");
TORCH_CHECK(scale.dim() == 2, "scale must be a 2D tensor");
TORCH_CHECK(scale.scalar_type() == at::ScalarType::Float, "scale must be a float tensor");
TORCH_CHECK(
inp.scalar_type() == at::ScalarType::Float || inp.scalar_type() == at::ScalarType::BFloat16,
"input must be a float or bfloat16 tensor");
TORCH_CHECK(out.scalar_type() == at::ScalarType::Byte, "output must be a uint8 tensor");
TORCH_CHECK(out_dtype == transformer_engine::DType::kFloat8E4M3 ||
out_dtype == transformer_engine::DType::kFloat8E5M2,
"out_dtype must be kFloat8E4M3 or kFloat8E5M2");
size_t scale_stride_h = scale.stride(0);
size_t scale_stride_w = scale.stride(1);
void fp8_block_scaling_partial_cast(const Tensor inp, Tensor out, const Tensor scale, size_t h,
size_t w, size_t scale_stride_h, size_t scale_stride_w,
size_t start_offset, size_t block_len, const DType out_dtype,
cudaStream_t stream) {
NVTE_CHECK(block_len == 128, "Currently only block_len = 128 is supported");
size_t len = inp.numel();
assert(h > 0 && w > 0);
......@@ -213,17 +203,43 @@ void fp8_block_scaling_partial_cast(const at::Tensor &inp, at::Tensor out, const
assert(blocks_y <= std::numeric_limits<unsigned int>::max());
dim3 grid(blocks_x, blocks_y);
auto stream = at::cuda::getCurrentCUDAStream();
DISPATCH_FLOAT_HALF_AND_BFLOAT(
inp.scalar_type(), 0, "partial_cast",
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
inp.dtype(), inp_dtype,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
out_dtype, fp8_type,
TRANSFORMER_ENGINE_SWITCH_CONDITION(
w % kTileDim == 0, kWidthAligned,
fp8_block_scaling_partial_cast_kernel<scalar_t_0, fp8_type, kWidthAligned>
<<<grid, kThreadsPerBlock, 0, stream>>>(inp.data_ptr<scalar_t_0>(),
reinterpret_cast<fp8_type *>(out.data_ptr()),
scale.data_ptr<float>(), scale_stride_h,
scale_stride_w, h, w, start_offset, len);)))
fp8_block_scaling_partial_cast_kernel<inp_dtype, fp8_type, kWidthAligned>
<<<grid, kThreadsPerBlock, 0, stream>>>(
reinterpret_cast<const inp_dtype *>(inp.data.dptr),
reinterpret_cast<fp8_type *>(out.data.dptr),
reinterpret_cast<const float *>(scale.data.dptr), scale_stride_h, scale_stride_w,
h, w, start_offset, len);)))
}
} // namespace fp8_block_scaling_recipe
} // namespace transformer_engine
void nvte_fp8_block_scaling_compute_partial_amax(const NVTETensor inp, NVTETensor amax, size_t h,
size_t w, size_t amax_stride_h,
size_t amax_stride_w, size_t start_offset,
size_t block_len, cudaStream_t stream) {
NVTE_API_CALL(nvte_fp8_block_scaling_compute_partial_amax);
using namespace transformer_engine;
fp8_block_scaling_recipe::fp8_block_scaling_compute_partial_amax(
*reinterpret_cast<const Tensor *>(inp), *reinterpret_cast<Tensor *>(amax), h, w,
amax_stride_h, amax_stride_w, start_offset, block_len, stream);
}
void nvte_fp8_block_scaling_partial_cast(const NVTETensor inp, NVTETensor out,
const NVTETensor scale, size_t h, size_t w,
size_t scale_stride_h, size_t scale_stride_w,
size_t start_offset, size_t block_len,
const NVTEDType out_dtype, cudaStream_t stream) {
NVTE_API_CALL(nvte_fp8_block_scaling_partial_cast);
using namespace transformer_engine;
fp8_block_scaling_recipe::fp8_block_scaling_partial_cast(
*reinterpret_cast<const Tensor *>(inp), *reinterpret_cast<Tensor *>(out),
*reinterpret_cast<const Tensor *>(scale), h, w, scale_stride_h, scale_stride_w, start_offset,
block_len, static_cast<DType>(out_dtype), stream);
}
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
// PyTorch-facing wrapper for the partial-amax computation.
//
// Validates the arguments (block_len must be 128, amax must be a 2D float
// tensor, tensor must be float or bfloat16), wraps both tensors for
// Transformer Engine, and calls nvte_fp8_block_scaling_compute_partial_amax
// with amax's strides on the current CUDA stream.
void fp8_block_scaling_compute_partial_amax(const at::Tensor &tensor, at::Tensor amax, size_t h,
                                            size_t w, size_t start_offset, size_t block_len) {
  using namespace transformer_engine;
  using namespace transformer_engine::pytorch;

  // Argument validation; the order of the checks is kept as-is.
  TORCH_CHECK(block_len == 128, "Currently only block_len = 128 is supported");
  TORCH_CHECK(amax.dim() == 2, "amax must be a 2D tensor");
  TORCH_CHECK(amax.scalar_type() == at::ScalarType::Float, "amax must be a float tensor");
  const auto input_type = tensor.scalar_type();
  TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::BFloat16,
              "tensor must be a float or bfloat16 tensor");

  // Wrap the PyTorch tensors for the core library.
  const TensorWrapper te_input = makeTransformerEngineTensor(tensor);
  TensorWrapper te_amax = makeTransformerEngineTensor(amax);

  // amax is known to be 2D here, so both strides exist.
  const size_t amax_stride_h = static_cast<size_t>(amax.stride(0));
  const size_t amax_stride_w = static_cast<size_t>(amax.stride(1));
  cudaStream_t current_stream = at::cuda::getCurrentCUDAStream();

  nvte_fp8_block_scaling_compute_partial_amax(te_input.data(), te_amax.data(), h, w,
                                              amax_stride_h, amax_stride_w, start_offset,
                                              block_len, current_stream);
}
// PyTorch-facing wrapper for the partial cast.
//
// Validates the arguments (block_len must be 128, scale must be a 2D float
// tensor, inp must be float or bfloat16, out must be uint8, out_dtype must be
// one of the two FP8 types), wraps the tensors for Transformer Engine, and
// calls nvte_fp8_block_scaling_partial_cast with scale's strides on the
// current CUDA stream.
void fp8_block_scaling_partial_cast(const at::Tensor &inp, at::Tensor out, const at::Tensor &scale,
                                    size_t h, size_t w, size_t start_offset, size_t block_len,
                                    const transformer_engine::DType out_dtype) {
  using namespace transformer_engine;
  using namespace transformer_engine::pytorch;

  // Argument validation; the order of the checks is kept as-is.
  TORCH_CHECK(block_len == 128, "Currently only block_len = 128 is supported");
  TORCH_CHECK(scale.dim() == 2, "scale must be a 2D tensor");
  TORCH_CHECK(scale.scalar_type() == at::ScalarType::Float, "scale must be a float tensor");
  const auto input_type = inp.scalar_type();
  TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::BFloat16,
              "input must be a float or bfloat16 tensor");
  TORCH_CHECK(out.scalar_type() == at::ScalarType::Byte, "output must be a uint8 tensor");
  TORCH_CHECK(out_dtype == transformer_engine::DType::kFloat8E4M3 ||
                  out_dtype == transformer_engine::DType::kFloat8E5M2,
              "out_dtype must be kFloat8E4M3 or kFloat8E5M2");

  // Wrap the PyTorch tensors for the core library.
  const TensorWrapper te_input = makeTransformerEngineTensor(inp);
  TensorWrapper te_output = makeTransformerEngineTensor(out);
  const TensorWrapper te_scale = makeTransformerEngineTensor(scale);

  // scale is known to be 2D here, so both strides exist.
  const size_t scale_stride_h = static_cast<size_t>(scale.stride(0));
  const size_t scale_stride_w = static_cast<size_t>(scale.stride(1));
  const NVTEDType nvte_out_dtype = static_cast<NVTEDType>(out_dtype);
  cudaStream_t current_stream = at::cuda::getCurrentCUDAStream();

  nvte_fp8_block_scaling_partial_cast(te_input.data(), te_output.data(), te_scale.data(), h, w,
                                      scale_stride_h, scale_stride_w, start_offset, block_len,
                                      nvte_out_dtype, current_stream);
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment