Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
7405fe09
Commit
7405fe09
authored
May 22, 2025
by
wenjh
Browse files
Merge branch 'develop_v2.3'
parents
7462e0e4
c636071d
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
22 additions
and
19 deletions
+22
-19
transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp
...mer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp
+2
-1
transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu
...ngine/common/comm_gemm_overlap/userbuffers/userbuffers.cu
+10
-10
transformer_engine/common/normalization/common.h
transformer_engine/common/normalization/common.h
+2
-2
transformer_engine/common/permutation/permutation.cu
transformer_engine/common/permutation/permutation.cu
+2
-2
transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu
...e/common/transpose/quantize_transpose_square_blockwise.cu
+4
-2
transformer_engine/common/utils.cuh
transformer_engine/common/utils.cuh
+2
-2
No files found.
transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp
View file @
7405fe09
...
@@ -11,6 +11,7 @@
...
@@ -11,6 +11,7 @@
#include <cassert>
#include <cassert>
#include <numeric>
#include <numeric>
#include "amd_detail/hip_float8.h"
#include "common/common.h"
#include "common/common.h"
#include "common/util/cuda_driver.h"
#include "common/util/cuda_driver.h"
#include "common/util/cuda_runtime.h"
#include "common/util/cuda_runtime.h"
...
@@ -330,7 +331,7 @@ void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const Te
...
@@ -330,7 +331,7 @@ void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const Te
assert
(
rs_output
.
element_size
()
==
2
);
assert
(
rs_output
.
element_size
()
==
2
);
char
*
rs_output_ptr
=
reinterpret_cast
<
char
*>
(
rs_output
.
dptr
());
char
*
rs_output_ptr
=
reinterpret_cast
<
char
*>
(
rs_output
.
dptr
());
#ifdef USE_ROCM
#ifdef USE_ROCM
reducescatter2_userbuff_fp8
<
hip_f
8
<
hip_f8_type
::
bf8
>
>
(
rs_output_ptr
,
_ubuf
.
scale_inv
(),
_ub_reg
,
0
,
reducescatter2_userbuff_fp8
<
te_
hip_f
p8_e5m2
>
(
rs_output_ptr
,
_ubuf
.
scale_inv
(),
_ub_reg
,
0
,
#else
#else
reducescatter2_userbuff_fp8
<
__nv_fp8_e5m2
>
(
rs_output_ptr
,
_ubuf
.
scale_inv
(),
_ub_reg
,
0
,
reducescatter2_userbuff_fp8
<
__nv_fp8_e5m2
>
(
rs_output_ptr
,
_ubuf
.
scale_inv
(),
_ub_reg
,
0
,
#endif
#endif
...
...
transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu
View file @
7405fe09
...
@@ -2034,12 +2034,12 @@ void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const
...
@@ -2034,12 +2034,12 @@ void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const
}
}
#ifdef __HIP_PLATFORM_AMD__
#ifdef __HIP_PLATFORM_AMD__
template
void
reducescatter2_userbuff_stridedoutput_fp8
<
hip_f
8
<
hip_f8_type
::
bf8
>
>
(
template
void
reducescatter2_userbuff_stridedoutput_fp8
<
te_
hip_f
p8_e5m2
>(
void
*
output
,
float
*
scale
,
const
int
handler
,
const
int
offset
,
const
int
rowelements
,
void
*
output
,
float
*
scale
,
const
int
handler
,
const
int
offset
,
const
int
rowelements
,
const
int
colelements
,
const
int
strideelements
,
communicator
*
comm
,
cudaStream_t
stream
,
const
int
colelements
,
const
int
strideelements
,
communicator
*
comm
,
cudaStream_t
stream
,
cudaEvent_t
comm_launch_event
);
cudaEvent_t
comm_launch_event
);
template
void
reducescatter2_userbuff_stridedoutput_fp8
<
hip_f
8
<
hip_f8_type
::
fp8
>
>
(
template
void
reducescatter2_userbuff_stridedoutput_fp8
<
te_
hip_f
p8_e4m3
>(
void
*
output
,
float
*
scale
,
const
int
handler
,
const
int
offset
,
const
int
rowelements
,
void
*
output
,
float
*
scale
,
const
int
handler
,
const
int
offset
,
const
int
rowelements
,
const
int
colelements
,
const
int
strideelements
,
communicator
*
comm
,
cudaStream_t
stream
,
const
int
colelements
,
const
int
strideelements
,
communicator
*
comm
,
cudaStream_t
stream
,
cudaEvent_t
comm_launch_event
);
cudaEvent_t
comm_launch_event
);
...
@@ -2052,30 +2052,30 @@ void reducescatter2_userbuff_fp8(void *output, float *scale, const int handler,
...
@@ -2052,30 +2052,30 @@ void reducescatter2_userbuff_fp8(void *output, float *scale, const int handler,
comm
,
stream
,
comm_launch_event
);
comm
,
stream
,
comm_launch_event
);
}
}
template
void
reducescatter2_userbuff_fp8
<
hip_f
8
<
hip_f8_type
::
bf8
>
>
(
void
*
output
,
float
*
scale
,
template
void
reducescatter2_userbuff_fp8
<
te_
hip_f
p8_e5m2
>(
void
*
output
,
float
*
scale
,
const
int
handler
,
const
int
offset
,
const
int
handler
,
const
int
offset
,
const
int
elements
,
communicator
*
comm
,
const
int
elements
,
communicator
*
comm
,
cudaStream_t
stream
,
cudaStream_t
stream
,
cudaEvent_t
comm_launch_event
);
cudaEvent_t
comm_launch_event
);
template
void
reducescatter2_userbuff_fp8
<
hip_f
8
<
hip_f8_type
::
fp8
>
>
(
void
*
output
,
float
*
scale
,
template
void
reducescatter2_userbuff_fp8
<
te_
hip_f
p8_e4m3
>(
void
*
output
,
float
*
scale
,
const
int
handler
,
const
int
offset
,
const
int
handler
,
const
int
offset
,
const
int
elements
,
communicator
*
comm
,
const
int
elements
,
communicator
*
comm
,
cudaStream_t
stream
,
cudaStream_t
stream
,
cudaEvent_t
comm_launch_event
);
cudaEvent_t
comm_launch_event
);
template
void
reducescatter2_userbuff_strided_atomic_fp8
<
hip_f
8
<
hip_f8_type
::
fp8
>
>
(
template
void
reducescatter2_userbuff_strided_atomic_fp8
<
te_
hip_f
p8_e4m3
>(
void
*
output
,
float
*
scale
,
const
int
handler
,
const
int
offset
,
const
int
rowelements
,
void
*
output
,
float
*
scale
,
const
int
handler
,
const
int
offset
,
const
int
rowelements
,
const
int
colelements
,
const
int
strideelements_out
,
const
int
strideelements_in
,
const
int
colelements
,
const
int
strideelements_out
,
const
int
strideelements_in
,
const
int
numchunks
,
void
*
counters
,
communicator
*
comm
,
cudaStream_t
stream
);
const
int
numchunks
,
void
*
counters
,
communicator
*
comm
,
cudaStream_t
stream
);
template
void
reducescatter2_userbuff_strided_atomic_fp8
<
hip_f
8
<
hip_f8_type
::
bf8
>
>
(
template
void
reducescatter2_userbuff_strided_atomic_fp8
<
te_
hip_f
p8_e5m2
>(
void
*
output
,
float
*
scale
,
const
int
handler
,
const
int
offset
,
const
int
rowelements
,
void
*
output
,
float
*
scale
,
const
int
handler
,
const
int
offset
,
const
int
rowelements
,
const
int
colelements
,
const
int
strideelements_out
,
const
int
strideelements_in
,
const
int
colelements
,
const
int
strideelements_out
,
const
int
strideelements_in
,
const
int
numchunks
,
void
*
counters
,
communicator
*
comm
,
cudaStream_t
stream
);
const
int
numchunks
,
void
*
counters
,
communicator
*
comm
,
cudaStream_t
stream
);
template
void
reducescatter2_userbuff_strided_multiatomic_fp8
<
hip_f
8
<
hip_f8_type
::
fp8
>
>
(
template
void
reducescatter2_userbuff_strided_multiatomic_fp8
<
te_
hip_f
p8_e4m3
>(
void
*
output
,
float
*
scale
,
const
int
handler
,
const
int
offset
,
const
int
rowelements
,
void
*
output
,
float
*
scale
,
const
int
handler
,
const
int
offset
,
const
int
rowelements
,
const
int
colelements
,
const
int
strideelements_out
,
const
int
strideelements_in
,
const
int
colelements
,
const
int
strideelements_out
,
const
int
strideelements_in
,
const
int
numchunks
,
void
*
counters
,
communicator
*
comm
,
cudaStream_t
stream
);
const
int
numchunks
,
void
*
counters
,
communicator
*
comm
,
cudaStream_t
stream
);
template
void
reducescatter2_userbuff_strided_multiatomic_fp8
<
hip_f
8
<
hip_f8_type
::
bf8
>
>
(
template
void
reducescatter2_userbuff_strided_multiatomic_fp8
<
te_
hip_f
p8_e5m2
>(
void
*
output
,
float
*
scale
,
const
int
handler
,
const
int
offset
,
const
int
rowelements
,
void
*
output
,
float
*
scale
,
const
int
handler
,
const
int
offset
,
const
int
rowelements
,
const
int
colelements
,
const
int
strideelements_out
,
const
int
strideelements_in
,
const
int
colelements
,
const
int
strideelements_out
,
const
int
strideelements_in
,
const
int
numchunks
,
void
*
counters
,
communicator
*
comm
,
cudaStream_t
stream
);
const
int
numchunks
,
void
*
counters
,
communicator
*
comm
,
cudaStream_t
stream
);
...
@@ -2845,10 +2845,10 @@ void reduce_fp8_in_bf16_out(void *inputs, void *output, float *scale, int num_in
...
@@ -2845,10 +2845,10 @@ void reduce_fp8_in_bf16_out(void *inputs, void *output, float *scale, int num_in
}
}
#ifdef __HIP_PLATFORM_AMD__
#ifdef __HIP_PLATFORM_AMD__
template
void
reduce_fp8_in_bf16_out
<
hip_f
8
<
hip_f8_type
::
fp8
>
>
(
void
*
inputs
,
void
*
output
,
float
*
scale
,
template
void
reduce_fp8_in_bf16_out
<
te_
hip_f
p8_e4m3
>(
void
*
inputs
,
void
*
output
,
float
*
scale
,
int
num_inputs
,
int
input_size
,
int
num_inputs
,
int
input_size
,
cudaStream_t
stream
);
cudaStream_t
stream
);
template
void
reduce_fp8_in_bf16_out
<
hip_f
8
<
hip_f8_type
::
bf8
>
>
(
void
*
inputs
,
void
*
output
,
float
*
scale
,
template
void
reduce_fp8_in_bf16_out
<
te_
hip_f
p8_e5m2
>(
void
*
inputs
,
void
*
output
,
float
*
scale
,
int
num_inputs
,
int
input_size
,
int
num_inputs
,
int
input_size
,
cudaStream_t
stream
);
cudaStream_t
stream
);
#else
#else
...
...
transformer_engine/common/normalization/common.h
View file @
7405fe09
...
@@ -334,8 +334,8 @@ using fp8e4m3 = __nv_fp8_e4m3;
...
@@ -334,8 +334,8 @@ using fp8e4m3 = __nv_fp8_e4m3;
using
fp8e5m2
=
__nv_fp8_e5m2
;
using
fp8e5m2
=
__nv_fp8_e5m2
;
#else
#else
using
bf16
=
__hip_bfloat16
;
using
bf16
=
__hip_bfloat16
;
using
fp8e4m3
=
hip_f
8
<
hip_f8_type
::
fp8
>
;
using
fp8e4m3
=
te_
hip_f
p8_e4m3
;
using
fp8e5m2
=
hip_f
8
<
hip_f8_type
::
bf8
>
;
using
fp8e5m2
=
te_
hip_f
p8_e5m2
;
#endif
#endif
template
<
typename
T
>
template
<
typename
T
>
...
...
transformer_engine/common/permutation/permutation.cu
View file @
7405fe09
...
@@ -11,8 +11,8 @@
...
@@ -11,8 +11,8 @@
#include "../common.h"
#include "../common.h"
#ifdef __HIP_PLATFORM_AMD__
#ifdef __HIP_PLATFORM_AMD__
using
__nv_fp8_e4m3
=
hip_f
8
<
hip_f8_type
::
fp8
>
;
using
__nv_fp8_e4m3
=
te_
hip_f
p8_e4m3
;
using
__nv_fp8_e5m2
=
hip_f
8
<
hip_f8_type
::
bf8
>
;
using
__nv_fp8_e5m2
=
te_
hip_f
p8_e5m2
;
#define __ldlu(x) __ldg(x)
#define __ldlu(x) __ldg(x)
#endif
#endif
...
...
transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu
View file @
7405fe09
...
@@ -135,7 +135,8 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
...
@@ -135,7 +135,8 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
// broadcast the amax to all threads in a warp from the lane 0
// broadcast the amax to all threads in a warp from the lane 0
constexpr
int
lane_zero
=
0
;
constexpr
int
lane_zero
=
0
;
#ifdef __HIP_PLATFORM_AMD__
#ifdef __HIP_PLATFORM_AMD__
warp_tile_amax
=
__shfl
(
warp_tile_amax
,
lane_zero
);
warp_tile_amax
=
__shfl
(
warp_tile_amax
,
lane_zero
,
THREADS_PER_WARP
);
__syncthreads
();
#else
#else
warp_tile_amax
=
__shfl_sync
(
0xFFFFFFFF
,
warp_tile_amax
,
lane_zero
);
warp_tile_amax
=
__shfl_sync
(
0xFFFFFFFF
,
warp_tile_amax
,
lane_zero
);
#endif
#endif
...
@@ -362,7 +363,8 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose
...
@@ -362,7 +363,8 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose
// broadcast the amax to all threads in a warp from the lane 0
// broadcast the amax to all threads in a warp from the lane 0
constexpr
int
lane_zero
=
0
;
constexpr
int
lane_zero
=
0
;
#ifdef __HIP_PLATFORM_AMD__
#ifdef __HIP_PLATFORM_AMD__
warp_tile_amax
=
__shfl
(
warp_tile_amax
,
lane_zero
);
warp_tile_amax
=
__shfl
(
warp_tile_amax
,
lane_zero
,
THREADS_PER_WARP
);
__syncthreads
();
#else
#else
warp_tile_amax
=
__shfl_sync
(
0xFFFFFFFF
,
warp_tile_amax
,
lane_zero
);
warp_tile_amax
=
__shfl_sync
(
0xFFFFFFFF
,
warp_tile_amax
,
lane_zero
);
#endif
#endif
...
...
transformer_engine/common/utils.cuh
View file @
7405fe09
...
@@ -986,8 +986,8 @@ __device__ __forceinline__ void reciprocal<float>(float *value_inv, const float
...
@@ -986,8 +986,8 @@ __device__ __forceinline__ void reciprocal<float>(float *value_inv, const float
using
fp8e4m3
=
__nv_fp8_e4m3
;
using
fp8e4m3
=
__nv_fp8_e4m3
;
using
fp8e5m2
=
__nv_fp8_e5m2
;
using
fp8e5m2
=
__nv_fp8_e5m2
;
#else
#else
using
fp8e4m3
=
hip_f
8
<
hip_f8_type
::
fp8
>
;
using
fp8e4m3
=
te_
hip_f
p8_e4m3
;
using
fp8e5m2
=
hip_f
8
<
hip_f8_type
::
bf8
>
;
using
fp8e5m2
=
te_
hip_f
p8_e5m2
;
#endif
#endif
using
e8m0_t
=
uint8_t
;
using
e8m0_t
=
uint8_t
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment