OpenDAS / TransformerEngine · commit b8fe26e7

[DCU] support blockwise fp8 quantize

Authored May 13, 2025 by yuguo
Parent: ab3e5a92

Showing 9 changed files with 71 additions and 9 deletions.
hipify_custom_map.json                                                      +4   -0
setup.py                                                                    +4   -4
tests/cpp/CMakeLists.txt                                                    +1   -1
tests/cpp/operator/CMakeLists.txt                                           +1   -1
transformer_engine/common/CMakeLists.txt                                    +2   -2
transformer_engine/common/common.h                                          +4   -0
transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu  +28  -0
transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu  +26  -0
transformer_engine/pytorch/module/base.py                                   +1   -1
hipify_custom_map.json

 {
   "custom_map": {
+    "common/utils.cuh": "common/utils_hip.cuh",
+    "common/transpose/cast_transpose.h": "common/transpose/cast_transpose_hip.h",
+    "common/recipe/recipe_common.cuh": "common/recipe/recipe_common_hip.cuh",
+    "common/util/ptx.cuh": "common/util/ptx_hip.cuh",
     "common/util/vectorized_pointwise.h": "common/util/vectorized_pointwise_hip.h",
     "common/common.h": "common/common_hip.h",
     "/userbuffers.h": "/userbuffers_hip.h",
     ...
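The four new entries extend TransformerEngine's hipify `custom_map`, which overrides the default CUDA-to-HIP header translation so that these includes resolve to their pre-ported `_hip` counterparts during the ROCm build. As a purely illustrative sketch (the file names come from the map above; nothing else here is project code), one mapped entry rewrites an include like this:

// CUDA source, before hipification:
#include "common/recipe/recipe_common.cuh"

// HIP source emitted by hipify once the custom map is applied:
#include "common/recipe/recipe_common_hip.cuh"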
setup.py

@@ -131,10 +131,10 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
     if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
         if "pytorch" in frameworks:
             install_reqs.extend(["torch>=2.1"])
-            install_reqs.append(
-                "nvdlfw-inspect @"
-                " git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git@v0.1#egg=nvdlfw-inspect"
-            )
+            # install_reqs.append(
+            #     "nvdlfw-inspect @"
+            #     " git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git@v0.1#egg=nvdlfw-inspect"
+            # )
             # Blackwell is not supported as of Triton 3.2.0, need custom internal build
             # install_reqs.append("triton")
             test_reqs.extend(["numpy", "torchvision", "prettytable", "PyYAML"])
tests/cpp/CMakeLists.txt

@@ -2,7 +2,7 @@
 #
 # See LICENSE for license information.
-# CXX=hipcc make build && cd build && cmake ../
+# mkdir build && cd build && CXX=hipcc cmake ../
 cmake_minimum_required(VERSION 3.18)
 option(USE_CUDA "Use CUDA" ON)
tests/cpp/operator/CMakeLists.txt

@@ -11,7 +11,7 @@ list(APPEND test_cuda_sources
     test_cast_mxfp8_gated_swiglu.cu
     test_qdq.cu
     test_cast_mxfp8.cu
-    # test_cast_float8blockwise.cu
+    test_cast_float8blockwise.cu
     test_dequantize_mxfp8.cu
     test_transpose.cu
     test_cast_transpose.cu
transformer_engine/common/CMakeLists.txt

@@ -163,8 +163,8 @@ else()
     transpose/cast_transpose_fusion.cu
     transpose/transpose_fusion.cu
     transpose/multi_cast_transpose.cu
-    # transpose/quantize_transpose_square_blockwise.cu
-    # transpose/quantize_transpose_vector_blockwise.cu
+    transpose/quantize_transpose_square_blockwise.cu
+    transpose/quantize_transpose_vector_blockwise.cu
     activation/gelu.cu
     activation/relu.cu
     activation/swiglu.cu
transformer_engine/common/common.h

@@ -294,7 +294,11 @@ struct TypeExtrema;
 template <>
 struct TypeExtrema<fp8e4m3> {
+#ifndef __HIP_PLATFORM_AMD__
   static constexpr float max = 448.0f;
+#else
+  static constexpr float max = 240.0f;
+#endif
 };

 template <>
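This is the core numerical difference behind the port: the OCP FP8 E4M3 format used on NVIDIA hardware has a largest finite value of 448, while AMD DCU/CDNA hardware implements the e4m3fnuz variant, whose largest finite value is 240. Blockwise quantization derives each block's scale from `max / amax`, so keeping 448 on fnuz hardware would produce out-of-range casts into the FP8 type. A minimal standalone sketch of the guarded-specialization pattern (the 448.0f/240.0f constants come from the diff; the stand-in type and `main` are illustrative, not TransformerEngine code):

// Compiles with g++, nvcc, or hipcc (C++17).
#include <cstdio>

struct fp8e4m3 {};  // stand-in tag type for the FP8 E4M3 element type

template <typename T> struct TypeExtrema;

template <> struct TypeExtrema<fp8e4m3> {
#ifndef __HIP_PLATFORM_AMD__
  // OCP FP8 E4M3 (e4m3fn): largest finite value is 448.
  static constexpr float max = 448.0f;
#else
  // AMD's e4m3fnuz variant: largest finite value is 240.
  static constexpr float max = 240.0f;
#endif
};

int main() {
  // Per-block scales of the form max / amax clamp values into this range.
  std::printf("fp8e4m3 max = %f\n", TypeExtrema<fp8e4m3>::max);
}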
transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu

@@ -5,12 +5,16 @@
  ************************************************************************/

 #include <cuda.h>
+#ifndef __HIP_PLATFORM_AMD__
 #include <cudaTypedefs.h>
+#endif
 #include <cuda_bf16.h>
 #include <cuda_runtime.h>

 #include <cfloat>
+#ifndef __HIP_PLATFORM_AMD__
 #include <cuda/barrier>
+#endif

 #include "common/common.h"
 #include "common/recipe/recipe_common.cuh"

@@ -69,7 +73,9 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
         const size_t num_rows, const size_t scale_stride_x,
         const size_t scale_stride_y, const size_t scale_t_stride_x,
         const size_t scale_t_stride_y, const float epsilon,
+#ifndef __HIP_PLATFORM_AMD__
         const __grid_constant__ CUtensorMap tensor_map_output_t,
+#endif
         bool pow_2_scaling) {
   using IVec = Vec<IType, THREAD_TILE_DIM_X>;
   using OVecCast = Vec<OType, THREAD_TILE_DIM_X>;

@@ -128,7 +134,11 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
   warp_tile_amax = warp_reduce_max<kThreadsPerWarp>(amax);
   // broadcast the amax to all threads in a warp from the lane 0
   constexpr int lane_zero = 0;
+#ifdef __HIP_PLATFORM_AMD__
+  warp_tile_amax = __shfl(warp_tile_amax, lane_zero);
+#else
   warp_tile_amax = __shfl_sync(0xFFFFFFFF, warp_tile_amax, lane_zero);
+#endif
   // reduce warp_tile_amax across multiple warps in a thread block using shared mem
   if (tid_in_warp == 0) {

@@ -351,7 +361,11 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose
   warp_tile_amax = warp_reduce_max<kThreadsPerWarp>(amax);
   // broadcast the amax to all threads in a warp from the lane 0
   constexpr int lane_zero = 0;
+#ifdef __HIP_PLATFORM_AMD__
+  warp_tile_amax = __shfl(warp_tile_amax, lane_zero);
+#else
   warp_tile_amax = __shfl_sync(0xFFFFFFFF, warp_tile_amax, lane_zero);
+#endif
   // reduce warp_tile_amax across multiple warps in a thread block using shared mem
   if (tid_in_warp == 0) {

@@ -447,6 +461,7 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose
   }
 }

+#ifndef __HIP_PLATFORM_AMD__
 template <typename OutputType>
 CUtensorMap get_tensor_map(const SimpleTensor& tensor, size_t global_dim_x, size_t global_dim_y) {
   CUtensorMapDataType dataType;

@@ -463,6 +478,7 @@ CUtensorMap get_tensor_map(const SimpleTensor& tensor, size_t global_dim_x, size
                  /*stride_elems=*/global_dim_x, /*offset_elems=*/0, sizeof(OutputType));
   return tensor_map_output_trans;
 }
+#endif

 }  // namespace
 }  // namespace transformer_engine

@@ -526,6 +542,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
       row_length % BLOCK_TILE_DIM == 0 && num_rows % BLOCK_TILE_DIM == 0;

   if (full_tile) {
+#ifndef __HIP_PLATFORM_AMD__
     CUtensorMap tensor_map_output_trans;
     if (return_transpose) {
       tensor_map_output_trans =

@@ -540,6 +557,17 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
             reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
             scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon,
             tensor_map_output_trans, pow_2_scale);
+#else
+    block_scaled_cast_transpose_kernel<kReturnTranspose, float, InputType, OutputType>
+        <<<grid, THREADS_PER_BLOCK, 0, stream>>>(
+            reinterpret_cast<const InputType*>(input.dptr),
+            reinterpret_cast<OutputType*>(output.dptr),
+            reinterpret_cast<OutputType*>(output_t.dptr),
+            reinterpret_cast<float*>(scale_inv.dptr),
+            reinterpret_cast<float*>(scale_inv_t.dptr),
+            row_length, num_rows, scale_stride_x, scale_stride_y,
+            scale_t_stride_x, scale_t_stride_y, epsilon, pow_2_scale);
+#endif
   } else {
     block_scaled_cast_transpose_kernel_notaligned<kReturnTranspose, float, InputType,
                                                   OutputType>
 ...
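Two portability patterns recur in this file. First, everything involving `CUtensorMap`, `cuda/barrier`, and `__grid_constant__` is compiled out under `__HIP_PLATFORM_AMD__`: those back the Hopper TMA path for storing the transposed output, which ROCm does not provide, so the HIP branch launches the kernel with a plain `output_t.dptr` pointer instead of a tensor map. Second, warp shuffles switch from CUDA's `__shfl_sync` (explicit lane mask, required since Volta) to HIP's maskless `__shfl`, which operates on the whole wavefront. A minimal sketch of the broadcast pattern the diff repeats in both kernels, hoisted into a helper (the helper name is illustrative, not from the source):

// Broadcast a value from lane 0 to every lane of the warp/wavefront.
// Compiles as a device helper under both nvcc and hipcc.
__device__ inline float broadcast_from_lane0(float v) {
#ifdef __HIP_PLATFORM_AMD__
  // HIP exposes only the maskless __shfl; wavefronts (64 lanes on CDNA
  // parts, versus 32-lane NVIDIA warps) execute in lockstep, so no
  // member mask is needed.
  return __shfl(v, 0);
#else
  // CUDA: explicit full-warp participation mask, source lane 0.
  return __shfl_sync(0xFFFFFFFF, v, 0);
#endif
}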
transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu

@@ -5,13 +5,17 @@
  ************************************************************************/

 #include <cuda.h>
+#ifndef __HIP_PLATFORM_AMD__
 #include <cudaTypedefs.h>
+#endif
 #include <cuda_bf16.h>
 #include <cuda_runtime.h>

 #include <algorithm>
 #include <cfloat>
+#ifndef __HIP_PLATFORM_AMD__
 #include <cuda/barrier>
+#endif
 #include <utility>

 #include "common/common.h"

@@ -252,12 +256,20 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
     // Step 2.3: Reduce amax
 #pragma unroll
     for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) {
+#ifdef __HIP_PLATFORM_AMD__
+      const float other_amax = __shfl_down(amax, delta);
+#else
       const float other_amax = __shfl_down_sync(mask, amax, delta);
+#endif
       __builtin_assume(amax >= 0);
       __builtin_assume(other_amax >= 0);
       amax = fmaxf(amax, other_amax);
     }
+#ifdef __HIP_PLATFORM_AMD__
+    amax = __shfl(amax, src_lane);
+#else
     amax = __shfl_sync(mask, amax, src_lane);
+#endif
     CType scale;
     // Step 2.4: Compute scale
     scale = compute_scale_from_types<IType, OType>(amax, epsilon, pow_2_scaling);

@@ -341,12 +353,20 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
     // Step 3.3: Reduce amax
 #pragma unroll
     for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) {
+#ifdef __HIP_PLATFORM_AMD__
+      const float other_amax = __shfl_down(amax, delta);
+#else
       const float other_amax = __shfl_down_sync(mask, amax, delta);
+#endif
       __builtin_assume(amax >= 0);
       __builtin_assume(other_amax >= 0);
       amax = fmaxf(amax, other_amax);
     }
+#ifdef __HIP_PLATFORM_AMD__
+    amax = __shfl(amax, src_lane);
+#else
     amax = __shfl_sync(mask, amax, src_lane);
+#endif
     // Step 3.4: Compute scale
     CType scale;
     scale = compute_scale_from_types<IType, OType>(amax, epsilon, pow_2_scaling);

@@ -472,9 +492,15 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
   size_t smem_bytes = kSMemSize * sizeof(InputType);
   // shared memory must be requested up
   if (smem_bytes >= 48 * 1024) {
+#ifdef __HIP_PLATFORM_AMD__
+    cudaError_t err = cudaFuncSetAttribute(
+        (const void*)&block_scaled_1d_cast_transpose_kernel<kAligned, float, InputType, OutputType>,
+        cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
+#else
     cudaError_t err = cudaFuncSetAttribute(
         &block_scaled_1d_cast_transpose_kernel<kAligned, float, InputType, OutputType>,
         cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
+#endif
     NVTE_CHECK(err == cudaSuccess, "Failed to set dynamic shared memory size.");
   }
   block_scaled_1d_cast_transpose_kernel<kAligned, float, InputType, OutputType>
       <<<grid, kThreadsPerBlock, smem_bytes, stream>>>(
 ...
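Beyond the same include and shuffle guards as the square-blockwise file, the only structural change here is the opt-in to more than 48 KiB of dynamic shared memory: the hipified `cudaFuncSetAttribute` resolves to `hipFuncSetAttribute`, which takes a `const void*` rather than a typed function pointer, so the HIP branch casts the kernel template instantiation explicitly. A minimal sketch under those assumptions (the kernel and helper names are illustrative, not TransformerEngine's API):

#ifdef __HIP_PLATFORM_AMD__
#include <hip/hip_runtime.h>
#else
#include <cuda_runtime.h>
#endif
#include <cstddef>

// Illustrative kernel standing in for block_scaled_1d_cast_transpose_kernel.
__global__ void my_kernel(float* out) {
  extern __shared__ float tile[];
  tile[threadIdx.x] = 0.0f;
  out[threadIdx.x] = tile[threadIdx.x];
}

// Opt a kernel into more than the default 48 KiB of dynamic shared memory.
bool request_smem(std::size_t smem_bytes) {
  if (smem_bytes < 48 * 1024) return true;  // default limit suffices
#ifdef __HIP_PLATFORM_AMD__
  // hipFuncSetAttribute takes const void*, hence the explicit cast --
  // the same cast the diff adds in front of the kernel template instance.
  hipError_t err = hipFuncSetAttribute(
      reinterpret_cast<const void*>(&my_kernel),
      hipFuncAttributeMaxDynamicSharedMemorySize, static_cast<int>(smem_bytes));
  return err == hipSuccess;
#else
  cudaError_t err = cudaFuncSetAttribute(
      &my_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
      static_cast<int>(smem_bytes));
  return err == cudaSuccess;
#endif
}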
transformer_engine/pytorch/module/base.py

@@ -180,7 +180,7 @@ def initialize_ub(
         which also requires `MPI_HOME=/path/to/mpi/root` to be set at compile time.
     """
     if not tex.device_supports_multicast():
-        assert bool(int(os.getenv("UB_SKIPMC", "0"))), (
+        assert bool(int(os.getenv("UB_SKIPMC", "1"))), (
             "CUDA device, driver and/or toolkit version does not support comm+GEMM overlap with "
             + "CUDA Multicast. Launch app with UB_SKIPMC=1 to try CUDA IPC instead."
         )
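This one-character change flips the default of `UB_SKIPMC` from `"0"` to `"1"`: on devices where `tex.device_supports_multicast()` returns false, `initialize_ub` now falls through to the CUDA IPC path by default instead of asserting, unless the user explicitly sets the variable back to 0. For context, a hedged sketch of the kind of capability probe that could sit behind such a check, assuming the CUDA 12.1+ driver attribute `CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED` (how TransformerEngine actually implements `device_supports_multicast()` may differ):

// Build with: nvcc probe.cu -lcuda
#include <cuda.h>
#include <cstdio>

int main() {
  cuInit(0);
  CUdevice dev;
  cuDeviceGet(&dev, 0);
  int multicast_supported = 0;
  cuDeviceGetAttribute(&multicast_supported,
                       CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, dev);
  // With this commit, a 0 here no longer aborts initialize_ub() by default:
  // UB_SKIPMC defaults to "1", so the CUDA IPC fallback is tried instead.
  std::printf("CUDA Multicast supported: %d\n", multicast_supported);
  return 0;
}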