Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
85ea3dd2
Commit
85ea3dd2
authored
Dec 31, 2025
by
wenjh
Browse files
Merge branch 'develop_v2.10' into release_v2.10
parents
aaab9f18
cb2fe806
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
64 additions
and
34 deletions
+64
-34
transformer_engine/common/CMakeLists.txt
transformer_engine/common/CMakeLists.txt
+1
-1
transformer_engine/common/common.h
transformer_engine/common/common.h
+2
-8
transformer_engine/common/gemm/cublaslt_gemm.cu
transformer_engine/common/gemm/cublaslt_gemm.cu
+4
-3
transformer_engine/common/gemm/rocm_gemm.cu
transformer_engine/common/gemm/rocm_gemm.cu
+41
-14
transformer_engine/common/util/ptx.cuh
transformer_engine/common/util/ptx.cuh
+16
-8
No files found.
transformer_engine/common/CMakeLists.txt
View file @
85ea3dd2
...
@@ -348,7 +348,7 @@ else()
...
@@ -348,7 +348,7 @@ else()
comm_gemm_overlap/userbuffers/userbuffers.cu
)
comm_gemm_overlap/userbuffers/userbuffers.cu
)
list
(
APPEND transformer_engine_cuda_arch_specific_sources
list
(
APPEND transformer_engine_cuda_arch_specific_sources
util
/cast.cu
cast
/cast.cu
activation/gelu.cu
activation/gelu.cu
activation/relu.cu
activation/relu.cu
activation/swiglu.cu
activation/swiglu.cu
...
...
transformer_engine/common/common.h
View file @
85ea3dd2
...
@@ -705,6 +705,7 @@ struct TypeInfo {
...
@@ -705,6 +705,7 @@ struct TypeInfo {
using type = bf16; \
using type = bf16; \
{ __VA_ARGS__ } \
{ __VA_ARGS__ } \
} break; \
} break; \
case DType::kInt8: \
case DType::kFloat8E5M2: \
case DType::kFloat8E5M2: \
case DType::kFloat8E4M3: { \
case DType::kFloat8E4M3: { \
NVTE_ERROR("FP8 type not instantiated for input."); \
NVTE_ERROR("FP8 type not instantiated for input."); \
...
@@ -712,10 +713,6 @@ struct TypeInfo {
...
@@ -712,10 +713,6 @@ struct TypeInfo {
case DType::kFloat4E2M1: { \
case DType::kFloat4E2M1: { \
NVTE_ERROR("FP4 type not instantiated for input."); \
NVTE_ERROR("FP4 type not instantiated for input."); \
} break; \
} break; \
case DType::kInt8: { \
using type = int8; \
{ __VA_ARGS__ } \
} break; \
default: \
default: \
NVTE_ERROR("Invalid type."); \
NVTE_ERROR("Invalid type."); \
}
}
...
@@ -735,14 +732,11 @@ struct TypeInfo {
...
@@ -735,14 +732,11 @@ struct TypeInfo {
using type = bf16; \
using type = bf16; \
{ __VA_ARGS__ } \
{ __VA_ARGS__ } \
} break; \
} break; \
case DType::kInt8: \
case DType::kFloat8E5M2: \
case DType::kFloat8E5M2: \
case DType::kFloat8E4M3: { \
case DType::kFloat8E4M3: { \
NVTE_ERROR("FP8 type not instantiated for input."); \
NVTE_ERROR("FP8 type not instantiated for input."); \
} break; \
} break; \
case DType::kInt8: { \
using type = int8; \
{ __VA_ARGS__ } \
} break; \
default: \
default: \
NVTE_ERROR("Invalid type."); \
NVTE_ERROR("Invalid type."); \
}
}
...
...
transformer_engine/common/gemm/cublaslt_gemm.cu
View file @
85ea3dd2
...
@@ -1425,13 +1425,14 @@ void nvte_grouped_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
...
@@ -1425,13 +1425,14 @@ void nvte_grouped_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
n
.
push_back
(
B0
);
n
.
push_back
(
B0
);
}
}
}
}
bool
use_bias
=
biasTensor
[
0
]
->
data
.
dptr
!=
nullptr
?
true
:
false
;
Tensor
*
wspace
=
convertNVTETensorCheck
(
workspace
[
0
]);
Tensor
*
wspace
=
convertNVTETensorCheck
(
workspace
[
0
]);
if
((
biasTensor
[
0
]
->
data
.
dptr
!=
nullptr
)
||
(
outputGelu
[
0
]
->
data
.
dptr
!=
nullptr
)
)
{
if
(
outputGelu
[
0
]
->
data
.
dptr
!=
nullptr
)
{
NVTE_ERROR
(
"MOE nvte_grouped_gemm not surpport
bias or
gelu."
);
NVTE_ERROR
(
"MOE nvte_grouped_gemm not surpport gelu."
);
}
}
hipblaslt_groupedgemm
(
inputA
,
inputB
,
outputD
,
m
,
n
,
k
,
b
,
hipblaslt_groupedgemm
(
inputA
,
inputB
,
outputD
,
biasTensor
,
use_bias
,
grad
,
m
,
n
,
k
,
b
,
(
transa
)
?
HIPBLAS_OP_T
:
HIPBLAS_OP_N
,
(
transa
)
?
HIPBLAS_OP_T
:
HIPBLAS_OP_N
,
(
transb
)
?
HIPBLAS_OP_T
:
HIPBLAS_OP_N
,
(
transb
)
?
HIPBLAS_OP_T
:
HIPBLAS_OP_N
,
wspace
->
data
.
dptr
,
wspace
->
data
.
shape
[
0
],
wspace
->
data
.
dptr
,
wspace
->
data
.
shape
[
0
],
...
...
transformer_engine/common/gemm/rocm_gemm.cu
View file @
85ea3dd2
...
@@ -362,9 +362,9 @@ __inline__ __device__ T WarpReduceSum(T val, int max = 32) {
...
@@ -362,9 +362,9 @@ __inline__ __device__ T WarpReduceSum(T val, int max = 32) {
return
val
;
return
val
;
}
}
template
<
typename
InputType
>
template
<
typename
InputType
,
typename
OutputType
>
__launch_bounds__
(
1024
)
__global__
__launch_bounds__
(
1024
)
__global__
void
bias_gradient_kernel_v2
(
float
*
dst
,
const
InputType
*
src
,
int
M
,
int
N
)
{
void
bias_gradient_kernel_v2
(
OutputType
*
dst
,
const
InputType
*
src
,
int
M
,
int
N
)
{
__shared__
float
g_shared
[
kColwiseReduceTileSize
][
kColwiseReduceTileSize
];
__shared__
float
g_shared
[
kColwiseReduceTileSize
][
kColwiseReduceTileSize
];
const
int
j
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
const
int
j
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
float
grad_sum
=
0.
f
;
float
grad_sum
=
0.
f
;
...
@@ -380,7 +380,7 @@ __launch_bounds__(1024) __global__
...
@@ -380,7 +380,7 @@ __launch_bounds__(1024) __global__
if
(
threadIdx
.
x
==
0
)
{
if
(
threadIdx
.
x
==
0
)
{
const
int
j
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
y
;
const
int
j
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
y
;
if
(
j
<
N
)
{
if
(
j
<
N
)
{
dst
[
j
]
=
static_cast
<
float
>
(
sum
);
dst
[
j
]
=
static_cast
<
OutputType
>
(
sum
);
}
}
}
}
}
}
...
@@ -409,8 +409,8 @@ __launch_bounds__(1024) __global__
...
@@ -409,8 +409,8 @@ __launch_bounds__(1024) __global__
}
}
}
}
template
<
typename
Tin
>
template
<
typename
Tin
,
typename
Tout
>
void
bias_gradient_kernelLauncher
(
const
Tin
*
in
,
floa
t
*
out
,
int
m
,
int
n
,
bool
stream_order_alloc
,
void
bias_gradient_kernelLauncher
(
const
Tin
*
in
,
Tou
t
*
out
,
int
m
,
int
n
,
bool
stream_order_alloc
,
hipStream_t
stream
)
{
hipStream_t
stream
)
{
dim3
block
,
grid
;
dim3
block
,
grid
;
constexpr
int
THREADS_PER_BLOCK
=
1024
;
constexpr
int
THREADS_PER_BLOCK
=
1024
;
...
@@ -418,13 +418,13 @@ void bias_gradient_kernelLauncher(const Tin* in, float* out, int m, int n, bool
...
@@ -418,13 +418,13 @@ void bias_gradient_kernelLauncher(const Tin* in, float* out, int m, int n, bool
block
.
x
=
THREADS_PER_BLOCK
;
block
.
x
=
THREADS_PER_BLOCK
;
grid
.
x
=
BLOCKS_PER_COL
*
n
;
grid
.
x
=
BLOCKS_PER_COL
*
n
;
if
(
!
stream_order_alloc
)
{
if
(
!
stream_order_alloc
)
{
NVTE_CHECK_CUDA
(
hipMemset
(
out
,
0
,
n
*
sizeof
(
floa
t
)));
NVTE_CHECK_CUDA
(
hipMemset
(
out
,
0
,
n
*
sizeof
(
Tou
t
)));
}
else
{
}
else
{
NVTE_CHECK_CUDA
(
hipMemsetAsync
(
out
,
0
,
n
*
sizeof
(
floa
t
),
stream
));
NVTE_CHECK_CUDA
(
hipMemsetAsync
(
out
,
0
,
n
*
sizeof
(
Tou
t
),
stream
));
}
}
// hipLaunchKernelGGL(( bias_gradient_kernel<Tin, THREADS_PER_BLOCK>), dim3(grid), dim3(block), 0, stream, in, out, m, n);
// hipLaunchKernelGGL(( bias_gradient_kernel<Tin, THREADS_PER_BLOCK>), dim3(grid), dim3(block), 0, stream, in, out, m, n);
int
B
=
(
n
-
1
)
/
kColwiseReduceTileSize
+
1
;
int
B
=
(
n
-
1
)
/
kColwiseReduceTileSize
+
1
;
bias_gradient_kernel_v2
<
Tin
>
bias_gradient_kernel_v2
<
Tin
,
Tout
>
<<<
B
,
dim3
(
kColwiseReduceTileSize
,
kColwiseReduceTileSize
),
0
,
stream
>>>
(
out
,
in
,
m
,
n
);
<<<
B
,
dim3
(
kColwiseReduceTileSize
,
kColwiseReduceTileSize
),
0
,
stream
>>>
(
out
,
in
,
m
,
n
);
}
}
...
@@ -893,7 +893,7 @@ static void CreateHipBlasLtHandle(hipblasLtHandle_t* handle) {
...
@@ -893,7 +893,7 @@ static void CreateHipBlasLtHandle(hipblasLtHandle_t* handle) {
}
}
static
void
DestroyHipBlasLtHandle
(
hipblasLtHandle_t
handle
)
{
static
void
DestroyHipBlasLtHandle
(
hipblasLtHandle_t
handle
)
{
if
(
handle
!=
nullptr
)
if
(
handle
!=
nullptr
)
{
NVTE_CHECK_HIPBLASLT
(
hipblasLtDestroy
(
handle
));
NVTE_CHECK_HIPBLASLT
(
hipblasLtDestroy
(
handle
));
}
}
}
}
...
@@ -1391,7 +1391,7 @@ struct HipBlasltUserArgsCache
...
@@ -1391,7 +1391,7 @@ struct HipBlasltUserArgsCache
{
{
HipBlasltUserArgsCache
()
{}
HipBlasltUserArgsCache
()
{}
HipBlasltUserArgsCache
(
const
HipBlasltUserArgsCache
&
)
=
delete
;
HipBlasltUserArgsCache
(
const
HipBlasltUserArgsCache
&
)
=
delete
;
HipBlasltUserArgs
Buffer
&
operator
=
(
const
HipBlasltUserArgs
Buffer
&
)
=
delete
;
HipBlasltUserArgs
Cache
&
operator
=
(
const
HipBlasltUserArgs
Cache
&
)
=
delete
;
HipBlasltUserArgsBuffer
&
getBuffer
(
hipStream_t
stream
,
size_t
size
,
bool
host
)
HipBlasltUserArgsBuffer
&
getBuffer
(
hipStream_t
stream
,
size_t
size
,
bool
host
)
{
{
std
::
unordered_map
<
size_t
,
HipBlasltUserArgsBuffer
>&
buffers
=
host
?
host_buffers_
:
device_buffers_
;
std
::
unordered_map
<
size_t
,
HipBlasltUserArgsBuffer
>&
buffers
=
host
?
host_buffers_
:
device_buffers_
;
...
@@ -1425,9 +1425,8 @@ struct HipBlasltUserArgsCacheManager {
...
@@ -1425,9 +1425,8 @@ struct HipBlasltUserArgsCacheManager {
std
::
vector
<
HipBlasltUserArgsCache
>
caches_
;
std
::
vector
<
HipBlasltUserArgsCache
>
caches_
;
};
};
void
hipblaslt_groupedgemm
(
std
::
vector
<
const
Tensor
*>&
inputA
,
std
::
vector
<
const
Tensor
*>&
inputB
,
void
hipblaslt_groupedgemm
(
std
::
vector
<
const
Tensor
*>&
inputA
,
std
::
vector
<
const
Tensor
*>&
inputB
,
std
::
vector
<
Tensor
*>&
outputD
,
std
::
vector
<
int64_t
>&
m
,
std
::
vector
<
Tensor
*>&
outputD
,
std
::
vector
<
const
Tensor
*>&
bias
,
bool
use_bias
,
bool
grad
,
std
::
vector
<
int64_t
>&
m
,
std
::
vector
<
int64_t
>&
n
,
std
::
vector
<
int64_t
>&
k
,
std
::
vector
<
int64_t
>&
b
,
std
::
vector
<
int64_t
>&
n
,
std
::
vector
<
int64_t
>&
k
,
std
::
vector
<
int64_t
>&
b
,
hipblasOperation_t
transa
,
hipblasOperation_t
transb
,
void
*
workspace
,
hipblasOperation_t
transa
,
hipblasOperation_t
transb
,
void
*
workspace
,
size_t
workspaceSize
,
bool
accumulate
,
bool
use_split_accumulator
,
size_t
workspaceSize
,
bool
accumulate
,
bool
use_split_accumulator
,
...
@@ -1467,6 +1466,13 @@ void hipblaslt_groupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
...
@@ -1467,6 +1466,13 @@ void hipblaslt_groupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
// No action needed, default is HIPBLASLT_EPILOGUE_DEFAULT. (Gemm only)
// No action needed, default is HIPBLASLT_EPILOGUE_DEFAULT. (Gemm only)
std
::
vector
<
hipblaslt_ext
::
GemmEpilogue
>
epilogue
{
hipblaslt_ext
::
GemmEpilogue
()};
std
::
vector
<
hipblaslt_ext
::
GemmEpilogue
>
epilogue
{
hipblaslt_ext
::
GemmEpilogue
()};
if
(
use_bias
&&
!
grad
)
{
const
hipDataType
bias_type
=
get_hipblaslt_dtype
(
bias
[
0
]
->
data
.
dtype
);
NVTE_CHECK
(
bias_type
==
HIP_R_32F
||
bias_type
==
HIP_R_16BF
);
epilogue
[
0
].
mode
=
HIPBLASLT_EPILOGUE_BIAS
;
epilogue
[
0
].
bias_data_type
=
bias_type
;
}
std
::
vector
<
hipblaslt_ext
::
GemmInputs
>
inputs
(
m
.
size
());
std
::
vector
<
hipblaslt_ext
::
GemmInputs
>
inputs
(
m
.
size
());
for
(
int
i
=
0
;
i
<
m
.
size
();
i
++
)
{
for
(
int
i
=
0
;
i
<
m
.
size
();
i
++
)
{
assert
(
m
[
i
]
!=
0
);
assert
(
m
[
i
]
!=
0
);
...
@@ -1477,6 +1483,7 @@ void hipblaslt_groupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
...
@@ -1477,6 +1483,7 @@ void hipblaslt_groupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
inputs
[
i
].
b
=
inputB
[
i
]
->
data
.
dptr
;
inputs
[
i
].
b
=
inputB
[
i
]
->
data
.
dptr
;
inputs
[
i
].
c
=
outputD
[
i
]
->
data
.
dptr
;
inputs
[
i
].
c
=
outputD
[
i
]
->
data
.
dptr
;
inputs
[
i
].
d
=
outputD
[
i
]
->
data
.
dptr
;
inputs
[
i
].
d
=
outputD
[
i
]
->
data
.
dptr
;
inputs
[
i
].
bias
=
bias
[
i
]
->
data
.
dptr
;
inputs
[
i
].
alpha
=
use_int8
?
static_cast
<
void
*>
(
&
int_one
)
:
static_cast
<
void
*>
(
&
one
);
inputs
[
i
].
alpha
=
use_int8
?
static_cast
<
void
*>
(
&
int_one
)
:
static_cast
<
void
*>
(
&
one
);
inputs
[
i
].
beta
=
use_int8
?
static_cast
<
void
*>
(
&
int_beta
)
:
static_cast
<
void
*>
(
&
beta
);
inputs
[
i
].
beta
=
use_int8
?
static_cast
<
void
*>
(
&
int_beta
)
:
static_cast
<
void
*>
(
&
beta
);
}
}
...
@@ -1512,6 +1519,26 @@ void hipblaslt_groupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
...
@@ -1512,6 +1519,26 @@ void hipblaslt_groupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
NVTE_CHECK_HIPBLASLT
(
groupedgemm
.
run
(
device_args
,
stream
));
NVTE_CHECK_HIPBLASLT
(
groupedgemm
.
run
(
device_args
,
stream
));
device_user_args
.
setStream
(
stream
);
device_user_args
.
setStream
(
stream
);
NVTE_CHECK_CUDA
(
hipEventRecord
(
device_event
,
stream
));
NVTE_CHECK_CUDA
(
hipEventRecord
(
device_event
,
stream
));
if
(
use_bias
&&
grad
)
{
DType
input_type
=
inputB
[
0
]
->
data
.
dtype
;
DType
bias_type
=
bias
[
0
]
->
data
.
dtype
;
NVTE_CHECK
(
bias_type
==
DType
::
kFloat32
||
bias_type
==
DType
::
kFloat16
||
bias_type
==
DType
::
kBFloat16
);
for
(
int
i
=
0
;
i
<
m
.
size
();
++
i
)
{
void
*
input_ptr
=
inputB
[
i
]
->
data
.
dptr
;
void
*
bias_ptr
=
bias
[
i
]
->
data
.
dptr
;
int
batch_size
=
static_cast
<
int
>
(
k
[
i
]);
int
output_dim
=
static_cast
<
int
>
(
n
[
i
]);
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT
(
input_type
,
IType
,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT
(
bias_type
,
OType
,
detail
::
bias_gradient_kernelLauncher
<
IType
,
OType
>
(
reinterpret_cast
<
const
IType
*>
(
input_ptr
),
reinterpret_cast
<
OType
*>
(
bias_ptr
),
batch_size
,
output_dim
,
true
,
stream
);));
}
}
}
}
#endif //USE_HIPBLASLT
#endif //USE_HIPBLASLT
...
@@ -1738,7 +1765,7 @@ void rocblas_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
...
@@ -1738,7 +1765,7 @@ void rocblas_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT
(
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT
(
output_dtype
,
OType
,
output_dtype
,
OType
,
detail
::
bias_gradient_kernelLauncher
<
OType
>
(
detail
::
bias_gradient_kernelLauncher
<
OType
,
float
>
(
reinterpret_cast
<
const
OType
*>
(
D
),
reinterpret_cast
<
float
*>
(
bias_tmp
),
batch_size
,
reinterpret_cast
<
const
OType
*>
(
D
),
reinterpret_cast
<
float
*>
(
bias_tmp
),
batch_size
,
input_dim
,
stream_order_alloc
,
stream
););
input_dim
,
stream_order_alloc
,
stream
););
...
@@ -1808,7 +1835,7 @@ void rocblas_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
...
@@ -1808,7 +1835,7 @@ void rocblas_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
DType
bias_dtype
=
get_transformer_engine_dtype
(
bias_type
);
DType
bias_dtype
=
get_transformer_engine_dtype
(
bias_type
);
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT
(
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT
(
input_dtype
,
IType
,
input_dtype
,
IType
,
detail
::
bias_gradient_kernelLauncher
<
IType
>
(
detail
::
bias_gradient_kernelLauncher
<
IType
,
float
>
(
reinterpret_cast
<
const
IType
*>
(
B
),
reinterpret_cast
<
float
*>
(
bias_tmp
),
batch_size
,
reinterpret_cast
<
const
IType
*>
(
B
),
reinterpret_cast
<
float
*>
(
bias_tmp
),
batch_size
,
output_dim
,
stream_order_alloc
,
stream
););
output_dim
,
stream_order_alloc
,
stream
););
if
(
bias_type
!=
rocblas_datatype_f32_r
)
{
if
(
bias_type
!=
rocblas_datatype_f32_r
)
{
...
...
transformer_engine/common/util/ptx.cuh
View file @
85ea3dd2
...
@@ -71,6 +71,13 @@ constexpr bool is_supported_arch() {
...
@@ -71,6 +71,13 @@ constexpr bool is_supported_arch() {
}
}
}
}
#ifdef __HIP_PLATFORM_AMD__
#define __CUDA_ARCH_HAS_FEATURE__(FEATURE) \
((__CUDA_ARCH__ >= 100 && FEATURE == SM100_ALL) || \
(__CUDA_ARCH__ >= 101 && FEATURE == SM101_ALL) || \
(__CUDA_ARCH__ >= 120 && FEATURE == SM120_ALL))
#endif
#if CUDA_VERSION < 12090
#if CUDA_VERSION < 12090
#if __CUDA_ARCH_HAS_FEATURE__(SM90_ALL)
#if __CUDA_ARCH_HAS_FEATURE__(SM90_ALL)
#define __CUDA_ARCH_SPECIFIC__ 900
#define __CUDA_ARCH_SPECIFIC__ 900
...
@@ -90,6 +97,7 @@ constexpr bool is_supported_arch() {
...
@@ -90,6 +97,7 @@ constexpr bool is_supported_arch() {
#endif
#endif
#endif
#endif
#ifdef __CUDA_ARCH__
#ifdef __CUDA_ARCH__
#define __NVTE_CURRENT_ARCH__ constexpr int current_arch = __CUDA_ARCH__;
#define __NVTE_CURRENT_ARCH__ constexpr int current_arch = __CUDA_ARCH__;
#else
#else
...
@@ -246,14 +254,6 @@ __device__ __forceinline__ void mbarrier_wait_parity(uint64_t *mbar, const uint3
...
@@ -246,14 +254,6 @@ __device__ __forceinline__ void mbarrier_wait_parity(uint64_t *mbar, const uint3
constexpr
uint32_t
FP32_MANTISSA_BITS
=
23
;
constexpr
uint32_t
FP32_MANTISSA_BITS
=
23
;
constexpr
uint32_t
FP32_EXPONENT_BIAS
=
127
;
constexpr
uint32_t
FP32_EXPONENT_BIAS
=
127
;
#ifdef __HIP_PLATFORM_AMD__
#define __CUDA_ARCH_HAS_FEATURE__(FEATURE) \
((__CUDA_ARCH__ >= 100 && FEATURE == SM100_ALL) || \
(__CUDA_ARCH__ >= 101 && FEATURE == SM101_ALL) || \
(__CUDA_ARCH__ >= 120 && FEATURE == SM120_ALL))
#endif
__device__
__forceinline__
float
exp2f_rcp
(
e8m0_t
biased_exp
)
{
__device__
__forceinline__
float
exp2f_rcp
(
e8m0_t
biased_exp
)
{
return
(
biased_exp
==
0
)
?
1
return
(
biased_exp
==
0
)
?
1
:
__int_as_float
((
254
-
biased_exp
)
:
__int_as_float
((
254
-
biased_exp
)
...
@@ -265,6 +265,9 @@ __device__ __forceinline__ float exp2f(e8m0_t biased_exp) {
...
@@ -265,6 +265,9 @@ __device__ __forceinline__ float exp2f(e8m0_t biased_exp) {
}
}
__device__
__forceinline__
e8m0_t
float_to_e8m0
(
float
val
)
{
__device__
__forceinline__
e8m0_t
float_to_e8m0
(
float
val
)
{
#ifdef __HIP_PLATFORM_AMD__
NVTE_DEVICE_ERROR
(
"float_to_e8m0 is not supported on rocm platform."
);
#else
constexpr
bool
is_blackwell
=
ARCH_BLACKWELL_FAMILY
;
constexpr
bool
is_blackwell
=
ARCH_BLACKWELL_FAMILY
;
if
constexpr
(
is_blackwell
)
{
if
constexpr
(
is_blackwell
)
{
uint16_t
out
;
uint16_t
out
;
...
@@ -296,6 +299,7 @@ __device__ __forceinline__ e8m0_t float_to_e8m0(float val) {
...
@@ -296,6 +299,7 @@ __device__ __forceinline__ e8m0_t float_to_e8m0(float val) {
}
}
return
exponent
;
return
exponent
;
}
}
#endif
}
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor
...
@@ -407,6 +411,8 @@ __device__ __forceinline__ void fence_proxy_async_shared_cta() {
...
@@ -407,6 +411,8 @@ __device__ __forceinline__ void fence_proxy_async_shared_cta() {
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
}
}
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
template
<
typename
T
>
template
<
typename
T
>
struct
alignas
(
2
*
sizeof
(
T
))
FPx2
{
struct
alignas
(
2
*
sizeof
(
T
))
FPx2
{
T
x
;
T
x
;
...
@@ -834,6 +840,8 @@ __device__ __forceinline__ void abs_max_2x(fp16x2 &dst, const fp16x2 &p1, const
...
@@ -834,6 +840,8 @@ __device__ __forceinline__ void abs_max_2x(fp16x2 &dst, const fp16x2 &p1, const
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)
}
}
#endif
}
// namespace ptx
}
// namespace ptx
namespace
{
namespace
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment