Unverified commit 877b7966 authored by Jianbing, committed by GitHub
Browse files

Feature fast cast-only mxfp8 (#2062)



* refactor mxfp8_cast_only kernel
Signed-off-by: Jianbing Dong <jianbingd@nvidia.com>

* fix ptx.cuh after format
Signed-off-by: Jianbing Dong <jianbingd@nvidia.com>

---------
Signed-off-by: Jianbing Dong <jianbingd@nvidia.com>
Co-authored-by: Oleg Goncharov <64355998+Oleg-Goncharov@users.noreply.github.com>
parent 41fb9bcf
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include "../../util/ptx.cuh" #include "../../util/ptx.cuh"
#include "../../utils.cuh" #include "../../utils.cuh"
#include "../core/common.cuh" #include "../core/common.cuh"
#include "specialized/quantize_mxfp8.cuh"
namespace transformer_engine { namespace transformer_engine {
namespace dispatch { namespace dispatch {
...@@ -619,6 +620,73 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, ...@@ -619,6 +620,73 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
output->dtype(), OType, output->dtype(), OType,
if (specialized::hasSpec<IS_DBIAS, IS_DACT, IS_ACT, IType, OType>()) {
switch (scaling_type) {
case ScalingType::ROWWISE: {
using traits = specialized::CastTraits<IType, OType, true, false>;
auto kernel = specialized::quantize_mxfp8_kernel_cast_only<traits>;
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
traits::smem);
dim3 block(traits::threadLayout::num, traits::warpLayout::N, traits::warpLayout::M);
dim3 grid((cols + traits::blockDimN - 1) / traits::blockDimN,
(rows + traits::blockDimM - 1) / traits::blockDimM);
kernel<<<grid, block, traits::smem, stream>>>(
reinterpret_cast<typename traits::IType *>(input.data.dptr),
reinterpret_cast<typename traits::OType *>(output->data.dptr),
scales_rowwise_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise);
break;
}
case ScalingType::COLWISE: {
NVTE_WARN("Colwise scaling will fallback to original kernel.");
break;
}
case ScalingType::BIDIMENSIONAL: {
using traits = specialized::CastTraits<IType, OType, true, true>;
auto kernel = specialized::quantize_mxfp8_kernel_cast_only<traits>;
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
traits::smem);
// TMA for loading, so that we don't need STS for transposing
alignas(64) CUtensorMap tensor_map_input{};
constexpr size_t input_type_bit_size = TypeInfo<IType>::size;
create_2D_tensor_map(tensor_map_input, input.data, rows, cols,
traits::blockIterDim::M, traits::blockIterDim::N,
/*stride_elems=*/cols,
/*offset_elems=*/0, input_type_bit_size,
traits::input_swizzle_pattern);
alignas(64) CUtensorMap tensor_map_rowwise_output{};
alignas(64) CUtensorMap tensor_map_colwise_output{};
constexpr size_t output_type_bit_size = TypeInfo<OType>::size;
create_2D_tensor_map(tensor_map_rowwise_output, output->data, rows, cols,
traits::blockIterDim::M, traits::blockIterDim::N,
/*stride_elems=*/cols,
/*offset_elems=*/0, output_type_bit_size,
traits::output_swizzle_pattern);
create_2D_tensor_map(tensor_map_colwise_output, output->columnwise_data, rows, cols,
traits::blockIterDim::M, traits::blockIterDim::N, cols, 0,
output_type_bit_size, traits::output_swizzle_pattern);
dim3 block(traits::rowThreadLayout::num, traits::numWarps);
dim3 grid((cols + traits::blockDIM::N - 1) / traits::blockDIM::N,
(rows + traits::blockDIM::M - 1) / traits::blockDIM::M);
kernel<<<grid, block, traits::smem, stream>>>(
tensor_map_input, tensor_map_rowwise_output, tensor_map_colwise_output,
scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise);
break;
}
default: {
NVTE_ERROR("Invalid scaling type.");
}
}
return;
}
alignas(64) CUtensorMap tensor_map_input{}; alignas(64) CUtensorMap tensor_map_input{};
alignas(64) CUtensorMap tensor_map_act_input{}; alignas(64) CUtensorMap tensor_map_act_input{};
alignas(64) CUtensorMap tensor_map_output_rowwise{}; alignas(64) CUtensorMap tensor_map_output_rowwise{};
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file state_counter.cuh
* \brief CUDA kernels to count state.
*/
#ifndef TRANSFORMER_ENGINE_SPECIALIZED_STATE_COUNTER_CUH_
#define TRANSFORMER_ENGINE_SPECIALIZED_STATE_COUNTER_CUH_
#include <cstdint>
namespace transformer_engine {
// Tracks the (stage index, phase bit) pair of a multi-stage software pipeline.
// The index cycles through 0..numStages-1 and the phase bit toggles each time
// the index wraps; with Flip=true the phase bit starts at 1 instead of 0.
template <int32_t numStages, bool Flip = false>
struct PipeState {
  int2 _storage;  // x: current stage index, y: current phase bit
  // Start at stage 0; the initial phase is 1 when Flip is set, 0 otherwise.
  __device__ __forceinline__ PipeState() : _storage{0, Flip ? 1 : 0} {}
  // Current stage index in [0, numStages).
  __device__ __forceinline__ int32_t index() const { return _storage.x; }
  // Current phase bit (0 or 1).
  __device__ __forceinline__ int32_t phase() const { return _storage.y; }
  // Post-increment: advance to the next stage, flipping the phase on wrap.
  __device__ __forceinline__ void operator++(int32_t) {
    if constexpr (numStages > 0) {
      if (++_storage.x == numStages) {
        _storage.x = 0;
        _storage.y ^= 1;
      }
    }
  }
};
// Wrapping stage counter without a phase bit: cycles 0, 1, ..., numStages-1.
template <int32_t numStages>
struct PipeStateCounter {
  int32_t _counter;  // current stage index
  // Start counting from stage 0.
  __device__ __forceinline__ PipeStateCounter() : _counter(0) {}
  // Current stage index in [0, numStages).
  __device__ __forceinline__ int32_t index() const { return _counter; }
  // Post-increment: advance and wrap back to zero after the last stage.
  __device__ __forceinline__ void operator++(int32_t) {
    if constexpr (numStages > 0) {
      if (++_counter == numStages) {
        _counter = 0;
      }
    }
  }
};
} // namespace transformer_engine
#endif // #ifndef TRANSFORMER_ENGINE_SPECIALIZED_STATE_COUNTER_CUH_
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file swizzle.cuh
* \brief CUDA kernels to swizzle.
*/
#ifndef TRANSFORMER_ENGINE_SPECIALIZED_SWIZZLE_CUH_
#define TRANSFORMER_ENGINE_SPECIALIZED_SWIZZLE_CUH_
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <type_traits>
namespace transformer_engine {
namespace swz {
// Compile-time constant wrapper: carries the value `v` in its type and is
// implicitly convertible to that value (a lightweight std::integral_constant).
template <auto v>
struct C {
  using type = C<v>;
  static constexpr auto value = v;
  using value_type = decltype(v);
  // Let a C<v> instance stand in wherever the raw value is expected.
  __device__ __host__ __forceinline__ constexpr operator value_type() const noexcept {
    return v;
  }
};

// Alias matching the std::integral_constant<T, v> spelling.
template <class T, T v>
using constant = C<v>;
// Shift `x` right by `s` bit positions. When the shift-amount type `Ts` is
// signed, a negative `s` encodes a left shift by -s instead (this sign trick
// is what lets Swizzle::apply use one expression for both shift directions).
//
// Fix: the original only handled Ts == uint32_t / int32_t and fell off the
// end (undefined behavior, no return value) for any other integral type.
// Generalized to all integral Ts with identical behavior for the two
// previously supported types; non-integral Ts is now a compile error.
template <class T, typename Ts, Ts s>
__host__ __device__ __forceinline__ constexpr T shiftr(T x) {
  static_assert(std::is_integral_v<Ts>, "shiftr: shift amount type must be integral");
  if constexpr (std::is_signed_v<Ts>) {
    if constexpr (s >= 0) {
      return x >> s;
    } else {
      return x << -s;  // negative amount means shift left
    }
  } else {
    // Unsigned shift amounts are always non-negative: plain right shift.
    return x >> s;
  }
}
// XOR-based offset swizzle functor (CuTe-style Swizzle<B,M,S>): permutes a
// linear offset by XOR-ing a band of "Y" bits, shifted by SShift, into the
// "Z" bits of the offset. NOTE(review): presumably used to avoid shared-memory
// bank conflicts when laying out tiles — confirm against the kernels using it.
//   BBits : width (in bits) of the swizzled band
//   MBase : number of least-significant bits left untouched (chunk size 2^MBase)
//   SShift: distance between the Y band and the Z band; may be negative
template <int32_t BBits, int32_t MBase, int32_t SShift>
struct Swizzle {
  static constexpr int32_t num_bits = BBits;    // number of rows
  static constexpr int32_t num_base = MBase;    // number of elements within a chunk
  static constexpr int32_t num_shft = SShift;   // number of columns, at the granularity of a chunk
  static_assert(num_base >= 0, "MBase must be non-negative");
  static_assert(num_bits >= 0, "BBits must be non-negative");
  // NOTE(review): relies on `abs` being usable in a constant expression; this
  // is accepted by nvcc but `std::abs` is not constexpr before C++23.
  static_assert(abs(num_shft) >= num_bits, "abs(SShift) must be greater than or equal to num_bits");
  // Contiguous run of num_bits ones, e.g. BBits=3 -> 0b111.
  using bit_mask = constant<int32_t, (1 << num_bits) - 1>;
  // Y band: the mask shifted up by the base plus (only if positive) the shift.
  using yyy_mask =
      constant<int32_t, bit_mask{} << (num_base + std::max(decltype(num_shft){0}, num_shft))>;
  // Z band: the mask shifted up by the base minus (only if negative) the shift.
  using zzz_mask =
      constant<int32_t, bit_mask{} << (num_base - std::min(decltype(num_shft){0}, num_shft))>;
  using msk_shft = constant<int32_t, num_shft>;
  // Union of both bands: the set of offset bits the swizzle can touch.
  static constexpr int32_t swz_code = int32_t(yyy_mask{} | zzz_mask{});
  // Core transform: XOR the Y-band bits of `offset`, moved down (or up, for
  // negative shift) by num_shft, into the offset. Self-inverse for SShift
  // magnitudes that keep the bands disjoint.
  template <class Offset>
  __host__ __device__ __forceinline__ constexpr static int32_t apply(Offset const &offset) {
    return offset ^
           shiftr<Offset, typename msk_shft::value_type, msk_shft::value>(offset & yyy_mask{});
  }
  // Convenience wrapper over apply() for plain int32_t offsets.
  __host__ __device__ __forceinline__ constexpr static int32_t swz(int32_t const &offset) {
    return apply(offset);
  }
};
// Identity layout: drop-in replacement for Swizzle that leaves every offset
// unchanged (no bit permutation applied).
struct Linear {
  // Pass the offset through untouched.
  template <class Offset>
  __host__ __device__ __forceinline__ constexpr static int32_t apply(Offset const &offset) {
    return static_cast<int32_t>(offset);
  }
  // Same contract as Swizzle::swz — here a no-op.
  __host__ __device__ __forceinline__ constexpr static int32_t swz(int32_t const &offset) {
    return offset;
  }
};
} // namespace swz
} // namespace transformer_engine
#endif // #ifndef TRANSFORMER_ENGINE_SPECIALIZED_SWIZZLE_CUH_
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment