OpenDAS / TransformerEngine · Commit ab122dac

[DCU] compile pass

Authored Mar 27, 2025 by yuguo
Parent: 4c6a5a27
Changes: 40 files in the commit; this page shows 20 changed files, with 367 additions and 36 deletions (+367 / -36).
Files changed on this page:

transformer_engine/common/swizzle/swizzle.cu  (+181 / -0)
transformer_engine/common/transpose/cast_transpose_fusion.cu  (+70 / -1)
transformer_engine/common/transpose/rtc/cast_transpose_fusion.cu  (+4 / -0)
transformer_engine/common/transpose/transpose_fusion.cu  (+20 / -2)
transformer_engine/common/util/cast_gated_kernels.cuh  (+6 / -4)
transformer_engine/common/util/cast_kernels.cuh  (+24 / -4)
transformer_engine/common/util/cuda_driver.cpp  (+4 / -0)
transformer_engine/common/util/cuda_nvml.cpp  (+4 / -0)
transformer_engine/common/util/cuda_runtime.cpp  (+6 / -0)
transformer_engine/common/util/dequantize_kernels.cuh  (+4 / -3)
transformer_engine/common/util/handle_manager.h  (+4 / -0)
transformer_engine/common/util/pybind_helper.h  (+4 / -0)
transformer_engine/common/util/rtc.cpp  (+4 / -0)
transformer_engine/common/util/rtc.h  (+5 / -0)
transformer_engine/common/utils.cuh  (+8 / -1)
transformer_engine/pytorch/csrc/extensions.h  (+1 / -1)
transformer_engine/pytorch/csrc/extensions/attention.cu  (+7 / -14)
transformer_engine/pytorch/csrc/extensions/gemm.cpp  (+8 / -3)
transformer_engine/pytorch/csrc/extensions/quantizer.cpp  (+2 / -2)
transformer_engine/pytorch/module/base.py  (+1 / -1)
transformer_engine/common/swizzle/swizzle.cu  (+181 / -0)

@@ -126,6 +126,84 @@ __global__ void swizzle_col_scaling_kernel(const void* input, void* output, cons
   }
 }
+
+#ifdef __HIP_PLATFORM_AMD__
+template <int SF_TILE_DIM_M, int SF_TILE_DIM_K>
+__global__ void swizzle_col_scaling_kernel_int(const void* input, void* output, const int M,
+                                               const int K) {
+  constexpr int N_TILE_PER_TD = sizeof(int) / sizeof(int);
+  constexpr int N_SF_PER_TD = N_TILE_PER_TD * N_SF_PER_TD_PER_TILE;
+  constexpr int SF_TILE_SIZE_I32 = SF_TILE_DIM_M * SF_TILE_DIM_K / 4;
+
+  // input is in M-major
+  constexpr int SF_TILE_DIM_M_I32 = SF_TILE_DIM_M / 4;
+  constexpr int SF_TILE_DIM_K_I32 = SF_TILE_DIM_K;
+  const int M_i32 = M / 4;
+  const int K_i32 = K;
+
+  int m_tiles_in_tb = N_TILE_PER_TD;
+  int k_tiles_in_tb = TB_DIM;
+  if (blockIdx.x == gridDim.x - 1) {
+    k_tiles_in_tb = (K_i32 / SF_TILE_DIM_K_I32 - 1) % k_tiles_in_tb + 1;
+  }
+  if (blockIdx.y == gridDim.y - 1) {
+    m_tiles_in_tb = (M_i32 / SF_TILE_DIM_M_I32 - 1) % m_tiles_in_tb + 1;
+  }
+
+  const int32_t* input_i32 = reinterpret_cast<const int32_t*>(input) +
+                             blockIdx.x * TB_DIM * SF_TILE_DIM_K_I32 * M_i32 +
+                             blockIdx.y * N_TILE_PER_TD * SF_TILE_DIM_M_I32;
+  int32_t* output_i32[N_TILE_PER_TD];
+#pragma unroll
+  for (int i = 0; i < m_tiles_in_tb; i++) {
+    output_i32[i] = reinterpret_cast<int32_t*>(output) + blockIdx.x * TB_DIM * SF_TILE_SIZE_I32 +
+                    (blockIdx.y * N_TILE_PER_TD + i) * SF_TILE_DIM_M_I32 * K_i32;
+  }
+
+  extern __shared__ int slm[];
+
+  // load, global -> regs
+  int regs_vec[N_SF_PER_TD_PER_TILE];
+  if (threadIdx.x * N_TILE_PER_TD < m_tiles_in_tb * SF_TILE_DIM_M_I32 &&
+      threadIdx.y < k_tiles_in_tb) {
+#pragma unroll
+    for (int i = 0; i < N_SF_PER_TD_PER_TILE; i++) {
+      regs_vec[i] = *reinterpret_cast<const int*>(
+          input_i32 + (threadIdx.y * SF_TILE_DIM_K_I32 + i) * M_i32 + threadIdx.x * N_TILE_PER_TD);
+    }
+
+    // local shuffle
+    regs_shuffle_with_bit_shifts(regs_vec);
+
+    // store, regs -> shared
+    int tM = threadIdx.x * N_SF_PER_TD;
+    int* slm_tile = slm + (threadIdx.y * SF_TILE_SIZE_I32 +
+                           tM / SF_TILE_DIM_M * k_tiles_in_tb * SF_TILE_SIZE_I32);
+#pragma unroll
+    for (int i = 0; i < N_SF_PER_TD; i++) {
+      /* TODO rotate_i */
+      slm_tile[(tM % SF_TILE_DIM_M) / NEW_SF_TILE_DIM_M_I32 +
+               ((tM + i) % NEW_SF_TILE_DIM_M_I32) * NEW_SF_TILE_DIM_K_I32] =
+          reinterpret_cast<int*>(regs_vec)[i];
+    }
+  }
+  __syncthreads();
+
+  // store, shared -> global
+  int linear_id = threadIdx.y * blockDim.x + threadIdx.x;
+#pragma unroll
+  for (int i = 0; i < m_tiles_in_tb; i++) {
+    __align__(16) int4* output_v4i = reinterpret_cast<int4*>(output_i32[i]);
+    __align__(16) int4* slm_v4i =
+        reinterpret_cast<int4*>(slm + i * k_tiles_in_tb * SF_TILE_SIZE_I32);
+#pragma unroll
+    for (int j = linear_id; j < SF_TILE_SIZE_I32 * k_tiles_in_tb / 4;
+         j += blockDim.x * blockDim.y) {
+      output_v4i[j] = slm_v4i[j];
+    }
+  }
+}
+#endif
+
 template <typename LType>
 __device__ inline void regs_shuffle(LType* regs_vec) {
   constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int);

@@ -196,6 +274,61 @@ __global__ void swizzle_row_scaling_kernel(const void* input, void* output, cons
   }
 }
+
+#ifdef __HIP_PLATFORM_AMD__
+template <int SF_TILE_DIM_M, int SF_TILE_DIM_K>
+__global__ void swizzle_row_scaling_kernel_int(const void* input, void* output, const int M,
+                                               const int K) {
+  constexpr int N_TILE_PER_TD = sizeof(int) / sizeof(int);
+  constexpr int N_TILES_IN_TB = TB_DIM * N_TILE_PER_TD;
+
+  // input is in K-major
+  constexpr int SF_TILE_SIZE_I32 = SF_TILE_DIM_M * SF_TILE_DIM_K / 4;
+  constexpr int SF_TILE_DIM_M_I32 = SF_TILE_DIM_M;
+
+  int n_tiles_in_tb = N_TILES_IN_TB;
+  const int K_i32 = K / 4;
+  if (blockIdx.x == gridDim.x - 1) {
+    n_tiles_in_tb = (K_i32 - 1) % N_TILES_IN_TB + 1;
+  }
+
+  const int* input_i32 = reinterpret_cast<const int*>(input) +
+                         blockIdx.y * SF_TILE_DIM_M_I32 * K_i32 + blockIdx.x * N_TILES_IN_TB;
+  int* output_i32 = reinterpret_cast<int*>(output) + blockIdx.y * SF_TILE_DIM_M_I32 * K_i32 +
+                    blockIdx.x * N_TILES_IN_TB * SF_TILE_SIZE_I32;
+
+  extern __shared__ int4 slm_v4i[];
+
+  // load, global -> regs
+  int regs_vec[N_SF_PER_TD_PER_TILE];
+  if (threadIdx.x * N_TILE_PER_TD < n_tiles_in_tb) {
+#pragma unroll
+    for (int i = 0; i < N_SF_PER_TD_PER_TILE; i++) {
+      regs_vec[i] = *reinterpret_cast<const int*>(
+          input_i32 + (i * TB_DIM + threadIdx.y) * K_i32 + threadIdx.x * N_TILE_PER_TD);
+    }
+
+    // shuffle regs
+    regs_shuffle<int>(regs_vec);
+
+    // store, regs -> shared
+#pragma unroll
+    for (int i = 0; i < N_TILE_PER_TD; i++) {
+      /* TODO rotate i */
+      slm_v4i[(threadIdx.x * N_TILE_PER_TD + i) * SF_TILE_SIZE_I32 / 4 + threadIdx.y] =
+          reinterpret_cast<int4*>(regs_vec)[i];
+    }
+  }
+  __syncthreads();
+
+  // store, shared -> global
+  int linear_id = threadIdx.y * blockDim.x + threadIdx.x;
+  __align__(16) int4* output_v4i = reinterpret_cast<int4*>(output_i32);
+#pragma unroll
+  for (int i = linear_id; i < SF_TILE_SIZE_I32 * n_tiles_in_tb / 4;
+       i += blockDim.x * blockDim.y) {
+    output_v4i[i] = slm_v4i[i];
+  }
+}
+#endif
+
 }  // namespace

 namespace transformer_engine {

@@ -253,6 +386,29 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
     dim3 num_blocks(DIVUP(num_tiles_k, n_tiles_in_tb), num_tiles_m);
     int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t);
     switch (vec_load_size) {
+#ifdef __HIP_PLATFORM_AMD__
+      case 4:
+        cudaFuncSetAttribute(
+            (const void*)swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
+            cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
+        swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
+            <<<num_blocks, block_size, slm_size, stream>>>(input->scale_inv.dptr,
+                                                           output->scale_inv.dptr, m, k);
+        break;
+      case 2:
+        cudaFuncSetAttribute(
+            (const void*)swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
+            cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
+        swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
+            <<<num_blocks, block_size, slm_size, stream>>>(input->scale_inv.dptr,
+                                                           output->scale_inv.dptr, m, k);
+        break;
+      case 1:
+        cudaFuncSetAttribute(
+            (const void*)swizzle_row_scaling_kernel_int<SF_TILE_DIM_M, SF_TILE_DIM_K>,
+            cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
+        swizzle_row_scaling_kernel_int<SF_TILE_DIM_M, SF_TILE_DIM_K>
+            <<<num_blocks, block_size, slm_size, stream>>>(input->scale_inv.dptr,
+                                                           output->scale_inv.dptr, m, k);
+        break;
+#else
       case 4:
         cudaFuncSetAttribute(swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
                              cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);

@@ -274,6 +430,7 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
             <<<num_blocks, block_size, slm_size, stream>>>(input->scale_inv.dptr,
                                                            output->scale_inv.dptr, m, k);
         break;
+#endif
       default:
         NVTE_ERROR("Not valid vec_load_size.");
         break;

@@ -286,6 +443,29 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
     dim3 num_blocks(DIVUP(num_tiles_k, TB_DIM), DIVUP(num_tiles_m, vec_load_size));
     int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t);
     switch (vec_load_size) {
+#ifdef __HIP_PLATFORM_AMD__
+      case 4:
+        cudaFuncSetAttribute(
+            (const void*)swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
+            cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
+        swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
+            <<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr,
+                                                           output->columnwise_scale_inv.dptr, m, k);
+        break;
+      case 2:
+        cudaFuncSetAttribute(
+            (const void*)swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
+            cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
+        swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
+            <<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr,
+                                                           output->columnwise_scale_inv.dptr, m, k);
+        break;
+      case 1:
+        cudaFuncSetAttribute(
+            (const void*)swizzle_col_scaling_kernel_int<SF_TILE_DIM_M, SF_TILE_DIM_K>,
+            cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
+        swizzle_col_scaling_kernel_int<SF_TILE_DIM_M, SF_TILE_DIM_K>
+            <<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr,
+                                                           output->columnwise_scale_inv.dptr, m, k);
+        break;
+#else
       case 4:
        cudaFuncSetAttribute(swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
                             cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);

@@ -307,6 +487,7 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
             <<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr,
                                                            output->columnwise_scale_inv.dptr, m, k);
         break;
+#endif
       default:
         NVTE_ERROR("Not valid vec_load_size.");
         break;
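A pattern worth noting in the hunks above: every HIP branch passes the kernel symbol to cudaFuncSetAttribute through an explicit (const void*) cast, while the CUDA branch passes the bare symbol. As a minimal standalone sketch (not from this commit; hipFuncSetAttribute's const void* signature is the assumption here), the same pattern in isolation looks like:

// sketch.hip.cpp - illustrative only; mirrors the (const void*) cast used above.
#include <hip/hip_runtime.h>

template <typename T>
__global__ void scale_kernel(T* data, T factor) {
  data[blockIdx.x * blockDim.x + threadIdx.x] *= factor;
}

void raise_smem_limit(int smem_bytes) {
  // hipFuncSetAttribute() takes `const void*`; hipcc does not accept the bare
  // templated kernel symbol the way nvcc does, hence the explicit cast.
  hipFuncSetAttribute((const void*)scale_kernel<float>,
                      hipFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
}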
transformer_engine/common/transpose/cast_transpose_fusion.cu  (+70 / -1)

@@ -170,7 +170,11 @@ inline __device__ void cast_and_transpose_regs(const CVec (&in)[nvec_out],
 #pragma unroll
   for (unsigned int j = 0; j < nvec_in; ++j) {
     CType elt = step_dbias.data.elt[j];
+#ifdef __HIP_PLATFORM_AMD__
+    elt = __shfl(elt, dbias_shfl_src_lane);  // shuffle data in a warp
+#else
     elt = __shfl_sync(0xffffffff, elt, dbias_shfl_src_lane);  // shuffle data in a warp
+#endif
     out_dbias.data.elt[j] += elt;
   }
 }

@@ -484,6 +488,50 @@ static const char *ActTypeToString[] = {
     "dsrelu"  // 12
 };
+
+#ifdef __HIP_PLATFORM_AMD__
+/* HIPCC has strict rules for __device__ functions usage on host.
+   It forbids not only calling but also other ODR-use assigning to variables
+   https://github.com/llvm/llvm-project/issues/105825
+   Use templated struct wrapper to work around
+*/
+template <typename ComputeType, typename ParamOP, ComputeType (*OP)(ComputeType, const ParamOP&)>
+struct ActivationType {
+  static constexpr auto op = OP;
+};
+
+template <typename ComputeType, typename ParamOP, ComputeType (*OP)(ComputeType, const ParamOP&)>
+int get_activation_type() {
+  using act = ActivationType<ComputeType, ParamOP, OP>;
+  if (act::op == ActivationType<ComputeType, ParamOP, &sigmoid<ComputeType, ComputeType>>::op) {
+    return 1;
+  } else if (act::op == ActivationType<ComputeType, ParamOP, &dsigmoid<ComputeType, ComputeType>>::op) {
+    return 2;
+  } else if (act::op == ActivationType<ComputeType, ParamOP, &gelu<ComputeType, ComputeType>>::op) {
+    return 3;
+  } else if (act::op == ActivationType<ComputeType, ParamOP, &dgelu<ComputeType, ComputeType>>::op) {
+    return 4;
+  } else if (act::op == ActivationType<ComputeType, ParamOP, &qgelu<ComputeType, ComputeType>>::op) {
+    return 5;
+  } else if (act::op == ActivationType<ComputeType, ParamOP, &dqgelu<ComputeType, ComputeType>>::op) {
+    return 6;
+  } else if (act::op == ActivationType<ComputeType, ParamOP, &silu<ComputeType, ComputeType>>::op) {
+    return 7;
+  } else if (act::op == ActivationType<ComputeType, ParamOP, &dsilu<ComputeType, ComputeType>>::op) {
+    return 8;
+  } else if (act::op == ActivationType<ComputeType, ParamOP, &relu<ComputeType, ComputeType>>::op) {
+    return 9;
+  } else if (act::op == ActivationType<ComputeType, ParamOP, &drelu<ComputeType, ComputeType>>::op) {
+    return 10;
+  } else if (act::op == ActivationType<ComputeType, ParamOP, &srelu<ComputeType, ComputeType>>::op) {
+    return 11;
+  } else if (act::op == ActivationType<ComputeType, ParamOP, &dsrelu<ComputeType, ComputeType>>::op) {
+    return 12;
+  } else {
+    return 0;
+  }
+}
+#else
 template <typename ComputeType, typename ParamOP, ComputeType (*OP)(ComputeType, const ParamOP&)>
 constexpr int get_activation_type() {
   constexpr decltype(OP) ActivationList[] = {

@@ -509,6 +557,7 @@ constexpr int get_activation_type() {
   }
   return 0;
 }
+#endif

 template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ComputeType, typename ParamOP,
           ComputeType (*OP)(ComputeType, const ParamOP&)>

@@ -734,11 +783,17 @@ void cast_transpose_fused(const Tensor &input, const Tensor *act_input, Tensor *
   NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape.");
   NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape.");
+#ifdef __HIP_PLATFORM_AMD__
+        cudaFuncSetAttribute(
+            (const void*)cast_transpose_fused_kernel_notaligned<IS_DBIAS, IS_DACT, IS_ACT,
+                                                                ComputeType, Param, nvec_in,
+                                                                nvec_out, Empty, OP>,
+            cudaFuncAttributePreferredSharedMemoryCarveout, 100);
+#else
         cudaFuncSetAttribute(
             cast_transpose_fused_kernel_notaligned<IS_DBIAS, IS_DACT, IS_ACT, ComputeType,
                                                    Param, nvec_in, nvec_out, Empty, OP>,
             cudaFuncAttributePreferredSharedMemoryCarveout, 100);
+#endif
         cast_transpose_fused_kernel_notaligned<IS_DBIAS, IS_DACT, IS_ACT, ComputeType, Param,
                                                nvec_in, nvec_out, Empty, OP>
             <<<num_blocks, cast_transpose_num_threads, shared_size_transpose, stream>>>(

@@ -1195,10 +1250,17 @@ void dgated_act_cast_transpose(const Tensor &input, const Tensor &gated_act_inpu
   const size_t shmem_size = cast_transpose_num_threads / n_warps_per_tile *
                             (THREADS_PER_WARP + 1) * sizeof(Vec<OutputType, nvec_out>);
   if (full_tile) {
+#ifdef __HIP_PLATFORM_AMD__
+    cudaFuncSetAttribute(
+        (const void*)dgated_act_cast_transpose_kernel<nvec_in, nvec_out, ComputeType, InputType,
+                                                      OutputType, Empty, OP1, OP2>,
+        cudaFuncAttributePreferredSharedMemoryCarveout, 100);
+#else
     cudaFuncSetAttribute(
         dgated_act_cast_transpose_kernel<nvec_in, nvec_out, ComputeType, InputType,
                                          OutputType, Empty, OP1, OP2>,
         cudaFuncAttributePreferredSharedMemoryCarveout, 100);
+#endif
     dgated_act_cast_transpose_kernel<nvec_in, nvec_out, ComputeType, InputType, OutputType,
                                      Empty, OP1, OP2>

@@ -1212,10 +1274,17 @@ void dgated_act_cast_transpose(const Tensor &input, const Tensor &gated_act_inpu
         reinterpret_cast<fp32*>(output->scale_inv.dptr), row_length, num_rows,
         n_tiles);
   } else {
+#ifdef __HIP_PLATFORM_AMD__
+    cudaFuncSetAttribute(
+        (const void*)dgated_act_cast_transpose_kernel_notaligned<nvec_in, nvec_out, ComputeType,
+                                                                 InputType, OutputType, Empty,
+                                                                 OP1, OP2>,
+        cudaFuncAttributePreferredSharedMemoryCarveout, 100);
+#else
     cudaFuncSetAttribute(
         dgated_act_cast_transpose_kernel_notaligned<nvec_in, nvec_out, ComputeType,
                                                     InputType, OutputType, Empty, OP1, OP2>,
         cudaFuncAttributePreferredSharedMemoryCarveout, 100);
+#endif
     dgated_act_cast_transpose_kernel_notaligned<nvec_in, nvec_out, ComputeType, InputType,
                                                 OutputType, Empty, OP1, OP2>
         <<<n_blocks, cast_transpose_num_threads, shmem_size, stream>>>(
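The ActivationType workaround above generalizes: hipcc rejects host code that ODR-uses the address of a __device__ function (for example, storing it into a host-side constexpr array, as the CUDA-path get_activation_type does), per the linked LLVM issue. Routing each pointer through a static constexpr member of a templated struct keeps the use in a form hipcc accepts. A stripped-down sketch of the same pattern, with illustrative names only:

// Hypothetical names; the wrapper shape matches ActivationType above.
template <typename T, T (*OP)(T)>
struct FnTag {
  static constexpr auto op = OP;  // pointer lives inside a template, not a host array
};

__device__ float relu_op(float x) { return x > 0.f ? x : 0.f; }
__device__ float gelu_op(float x);  // defined elsewhere

template <float (*OP)(float)>
int tag_of() {
  // Compare template-wrapped pointers instead of building decltype(OP) list[] = {...}.
  if (FnTag<float, OP>::op == FnTag<float, &relu_op>::op) return 1;
  if (FnTag<float, OP>::op == FnTag<float, &gelu_op>::op) return 2;
  return 0;
}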
transformer_engine/common/transpose/rtc/cast_transpose_fusion.cu  (+4 / -0)

@@ -90,7 +90,11 @@ inline __device__ void cast_and_transpose_regs_optimized(const CVec (&in)[NVEC_O
 #pragma unroll
   for (unsigned int j = 0; j < NVEC_IN; ++j) {
     CType elt = step_dbias.data.elt[j];
+#ifdef __HIP_PLATFORM_AMD__
+    elt = __shfl(elt, dbias_shfl_src_lane);  // shuffle data in a warp
+#else
     elt = __shfl_sync(0xffffffff, elt, dbias_shfl_src_lane);  // shuffle data in a warp
+#endif
     out_dbias.data.elt[j] += elt;
   }
 }
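The same warp-shuffle substitution appears in three files of this commit. The underlying difference: HIP has no __shfl_sync; its legacy __shfl(var, srcLane) operates on the full wavefront (64 lanes on CDNA hardware), so CUDA's thread-mask argument simply has no equivalent and is dropped. If the pattern spreads further, a small wrapper (hypothetical, not part of this commit) would centralize the #ifdef:

// warp_shfl: illustrative helper, not in the diff above.
__device__ inline float warp_shfl(float v, int src_lane) {
#ifdef __HIP_PLATFORM_AMD__
  return __shfl(v, src_lane);                   // whole wavefront; no mask argument
#else
  return __shfl_sync(0xffffffff, v, src_lane);  // full-warp mask required on CUDA
#endif
}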
transformer_engine/common/transpose/transpose_fusion.cu  (+20 / -2)

@@ -45,7 +45,11 @@ inline __device__ void transpose_regs_partial_dbias(const IVec (&in)[nvec_out],
 #pragma unroll
   for (unsigned int j = 0; j < nvec_in; ++j) {
     CType elt = step_dbias.data.elt[j];
+#ifdef __HIP_PLATFORM_AMD__
+    elt = __shfl(elt, dbias_shfl_src_lane);  // shuffle data in a warp
+#else
     elt = __shfl_sync(0xffffffff, elt, dbias_shfl_src_lane);  // shuffle data in warp
+#endif
     out_dbias.data.elt[j] += elt;
   }
 }

@@ -469,7 +473,21 @@ void fp8_transpose_dbias(const Tensor &input, Tensor *transposed_output, Tensor
   param.scale_inv =
       reinterpret_cast<const ComputeType*>(transposed_output->scale_inv.dptr);
   param.workspace = reinterpret_cast<ComputeType*>(workspace->data.dptr);
+#ifdef __HIP_PLATFORM_AMD__
+  if (full_tile) {
+    cudaFuncSetAttribute((const void*)transpose_dbias_kernel<nvec_in, nvec_out, Param>,
+                         cudaFuncAttributePreferredSharedMemoryCarveout, 100);
+    transpose_dbias_kernel<nvec_in, nvec_out, Param>
+        <<<n_blocks, cast_transpose_num_threads, shared_size_transpose, stream>>>(
+            param, row_length, num_rows, n_tiles);
+  } else {
+    cudaFuncSetAttribute((const void*)transpose_dbias_kernel_notaligned<nvec_in, nvec_out, Param>,
+                         cudaFuncAttributePreferredSharedMemoryCarveout, 100);
+    transpose_dbias_kernel_notaligned<nvec_in, nvec_out, Param>
+        <<<n_blocks, cast_transpose_num_threads, shared_size_transpose, stream>>>(
+            param, row_length, num_rows, n_tiles);
+  }
+#else
   if (full_tile) {
     cudaFuncSetAttribute(transpose_dbias_kernel<nvec_in, nvec_out, Param>,
                          cudaFuncAttributePreferredSharedMemoryCarveout, 100);

@@ -483,7 +501,7 @@ void fp8_transpose_dbias(const Tensor &input, Tensor *transposed_output, Tensor
         <<<n_blocks, cast_transpose_num_threads, shared_size_transpose, stream>>>(
             param, row_length, num_rows, n_tiles);
   }
+#endif
   reduce_dbias<BiasType>(*workspace, dbias, row_length, num_rows, nvec_out,
                          stream););  // NOLINT(*)
   );  // NOLINT(*)
transformer_engine/common/util/cast_gated_kernels.cuh  (+6 / -4)

@@ -54,6 +54,7 @@ static_assert(ITERATIONS >= 1);
 __device__ inline float sigmoidf(const float x) { return __frcp_rn(1.0f + __expf(-x)); }

+#ifndef __HIP_PLATFORM_AMD__
 template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP&),
           float (*DActOP)(float, const ParamOP&), typename IType, typename OType>
 __global__ void __launch_bounds__(THREADS_PER_CHUNK)

@@ -273,7 +274,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
   }
 #endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
 }
+#endif  // __HIP_PLATFORM_AMD__

+#ifndef __HIP_PLATFORM_AMD__
 template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP&),
           float (*DActOP)(float, const ParamOP&), typename IType, typename OType,
           size_t SCALE_DIM_Y, size_t SCALE_DIM_X>

@@ -720,14 +723,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
   }
 #endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
 }
+#endif  // __HIP_PLATFORM_AMD__

 template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP&),
           float (*DActOP)(float, const ParamOP&)>
 void cast_fp8_gated(const Tensor& grad, const Tensor& gated_input, Tensor* output,
                     cudaStream_t stream) {
 #ifdef __HIP_PLATFORM_AMD__
-  static_assert(false,
-                "Cast_fp8_gated is not surpported in rocm yet.");
+  assert(false);
 #else
   if (output->has_data()) {
     NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated.");

@@ -810,8 +813,7 @@ template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP
 void cast_mxfp8_gated(const Tensor& grad, const Tensor& gated_input, Tensor* output,
                       cudaStream_t stream) {
 #ifdef __HIP_PLATFORM_AMD__
-  static_assert(false,
-                "Cast_mxfp8_gated is not surpported in rocm yet.");
+  assert(false);
 #else
   const bool USE_ROWWISE_SCALING = output->has_data();
   const bool USE_COLWISE_SCALING = output->has_columnwise_data();
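The static_assert-to-assert swap here (and in cast_kernels.cuh, dequantize_kernels.cuh, and attention.cu below) is what makes the "[DCU] compile pass" goal reachable: static_assert(false, ...) fails the moment the #ifdef __HIP_PLATFORM_AMD__ branch is compiled at all, whereas assert(false) compiles everywhere and only traps if the unsupported path is actually executed. One caveat the diff does not address: assert comes from <cassert> and is compiled out under NDEBUG, so a release build would fall through silently; an unconditional error such as NVTE_ERROR would fail louder. A minimal sketch of the trade-off:

#include <cassert>

void unsupported_on_rocm_stub() {
#ifdef __HIP_PLATFORM_AMD__
  // Compiles on ROCm; traps at runtime in debug builds only (no-op under NDEBUG).
  assert(false && "not supported on ROCm yet");
#else
  // ... real implementation ...
#endif
}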
transformer_engine/common/util/cast_kernels.cuh  (+24 / -4)

@@ -56,6 +56,7 @@ constexpr size_t MXFP8_BUFF_STAGES_NUM =
 constexpr size_t MXFP8_ITERATIONS = MXFP8_CHUNK_DIM_Y / MXFP8_BUFFER_DIM_Y;  // 2 = 64 / 32
 static_assert(MXFP8_ITERATIONS >= MXFP8_PREFETCH_BUFFERS_NUM);

+#ifndef __HIP_PLATFORM_AMD__
 template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP,
           float (*OP)(float, const ParamOP&), typename IType, typename OType, size_t SCALE_DIM_Y,
           size_t SCALE_DIM_X>

@@ -462,6 +463,7 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
   destroy_barriers<MXFP8_ITERATIONS>(mbar, is_master_thread);
 #endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
 }
+#endif  // __HIP_PLATFORM_AMD__

 constexpr size_t FP8_CHUNK_DIM_Y = 128;
 constexpr size_t FP8_CHUNK_DIM_X = 128;

@@ -479,6 +481,7 @@ constexpr size_t FP8_BUFF_STAGES_NUM = FP8_BUFFER_DIM_Y;  // 16
 constexpr size_t FP8_ITERATIONS = FP8_CHUNK_DIM_Y / FP8_BUFFER_DIM_Y;  // 8 = 128 / 16
 static_assert(FP8_ITERATIONS >= FP8_PREFETCH_BUFFERS_NUM);

+#ifndef __HIP_PLATFORM_AMD__
 template <bool IS_DBIAS, bool IS_DACT, typename ParamOP, float (*OP)(float, const ParamOP&),
           typename IType, typename OType>
 __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK)

@@ -656,6 +659,7 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK)
   destroy_barriers<FP8_ITERATIONS>(mbar, is_master_thread);
 #endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
 }
+#endif  // __HIP_PLATFORM_AMD__

 constexpr size_t CHUNKS_PER_BLOCK = 128;
 constexpr size_t THREADS_PER_BLOCK = FP8_THREADS_PER_CHUNK;

@@ -856,8 +860,7 @@ template <bool IS_DBIAS, bool IS_DACT, typename ParamOP, float (*OP)(float, cons
 void cast_fp8_2D(const Tensor& input, const Tensor* act_input, Tensor* output, Tensor* dbias,
                  Tensor* workspace, cudaStream_t stream) {
 #ifdef __HIP_PLATFORM_AMD__
-  static_assert(false,
-                "Cast_fp8_2D is not surpported in rocm yet.");
+  assert(false);
 #else
   checkCuDriverContext(stream);

@@ -931,8 +934,7 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
                     const Tensor* noop,  // TODO (ksivamani)
                     Tensor* output, Tensor* dbias, Tensor* workspace, cudaStream_t stream) {
 #ifdef __HIP_PLATFORM_AMD__
-  static_assert(false,
-                "Mxfp8_quantize is not surpported in rocm yet.");
+  assert(false);
 #else
   bool use_rowwise_scaling = output->has_data();
   bool use_colwise_scaling = output->has_columnwise_data();

@@ -1057,10 +1059,23 @@ __device__ inline float dequantize_func(float value, const DequantizeParam &para
 }  // namespace detail

+#ifdef __HIP_PLATFORM_AMD__
+template <typename ParamOP, float (*OP)(float, const ParamOP&)>
+struct KernelType {
+  static constexpr auto op = OP;
+};
+#endif
+
 template <typename ParamOP, float (*OP)(float, const ParamOP&)>
 void CastVectorizedUnaryKernelLauncher(const Tensor& input, const Tensor* noop, Tensor* output,
                                        cudaStream_t stream) {
+#ifdef __HIP_PLATFORM_AMD__
+  using kernel = KernelType<ParamOP, OP>;
+  constexpr float (*UnaryOP)(float, const ParamOP&) =
+      (kernel::op == nullptr) ? KernelType<ParamOP, &detail::identity>::op : kernel::op;
+#else
   constexpr float (*UnaryOP)(float, const ParamOP&) = (OP == nullptr) ? detail::identity : OP;
+#endif
   const size_t N = product(input.data.shape);
   TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
       input.data.dtype, IType,

@@ -1084,7 +1099,12 @@ void CastVectorizedUnaryKernelLauncher(const Tensor &input, const Tensor *noop,
 template <typename ParamOP, float (*OP)(float, const ParamOP&)>
 void CastVectorizedUnaryGradKernelLauncher(const Tensor& grad, const Tensor* input,
                                            Tensor* output, cudaStream_t stream) {
+#ifdef __HIP_PLATFORM_AMD__
+  using kernel = KernelType<ParamOP, OP>;
+  constexpr float (*UnaryOP)(float, const ParamOP&) =
+      (kernel::op == nullptr) ? KernelType<ParamOP, &detail::identity>::op : kernel::op;
+#else
   constexpr float (*UnaryOP)(float, const ParamOP&) = (OP == nullptr) ? detail::identity : OP;
+#endif
   const size_t N = product(input->data.shape);
   TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
       input->data.dtype, IType,
transformer_engine/common/util/cuda_driver.cpp  (+4 / -0)

@@ -7,7 +7,11 @@
 #include <filesystem>

 #include "../common.h"
+#ifdef USE_ROCM
+#include "../util/hip_runtime.h"
+#else
 #include "../util/cuda_runtime.h"
+#endif

 namespace transformer_engine {
transformer_engine/common/util/cuda_nvml.cpp  (+4 / -0)

@@ -4,7 +4,11 @@
  * See LICENSE for license information.
  ************************************************************************/

+#ifdef USE_ROCM
+#include "hip_nvml.h"
+#else
 #include "cuda_nvml.h"
+#endif
 #include "shared_lib_wrapper.h"
transformer_engine/common/util/cuda_runtime.cpp  (+6 / -0)

@@ -10,9 +10,15 @@
 #include <mutex>

 #include "../common.h"
+#ifdef USE_ROCM
+#include "../util/hip_driver.h"
+#include "../util/system.h"
+#include "common/util/hip_runtime.h"
+#else
 #include "../util/cuda_driver.h"
 #include "../util/system.h"
 #include "common/util/cuda_runtime.h"
+#endif

 namespace transformer_engine {
transformer_engine/common/util/dequantize_kernels.cuh  (+4 / -3)

@@ -50,6 +50,7 @@ constexpr size_t THREADS_PER_CHUNK_X_COLWISE = CHUNK_DIM_X;
 constexpr size_t ITERATIONS = CHUNK_DIM_Y / BUFFER_DIM_Y;  // 8 = 128 / 16
 static_assert(ITERATIONS >= 1);

+#ifndef __HIP_PLATFORM_AMD__
 template <typename IType, typename OType, size_t SCALE_DIM_Y, size_t SCALE_DIM_X>
 __global__ void __launch_bounds__(THREADS_PER_CHUNK)
     dequantize_mxfp8_kernel(const __grid_constant__ CUtensorMap tensor_map_input,

@@ -229,6 +230,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
   }
 #endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
 }
+#endif  // __HIP_PLATFORM_AMD__

 static void fp8_dequantize(const Tensor& input, Tensor* output, cudaStream_t stream) {
   NVTE_CHECK(is_fp8_dtype(input.data.dtype), "Input must have FP8 type.");

@@ -253,8 +255,7 @@ static void fp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t str
 static void mxfp8_dequantize(const Tensor& input, Tensor* output, cudaStream_t stream) {
 #ifdef __HIP_PLATFORM_AMD__
-  static_assert(false,
-                "Mxfp8_dequantize is not surpported in rocm yet.");
+  assert(false);
 #else
   bool use_rowwise_scaling = input.has_data();
   bool use_colwise_scaling = input.has_columnwise_data();

@@ -337,8 +338,8 @@ static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t s
       );  // NOLINT(*)
       );  // NOLINT(*)
       );  // NOLINT(*)
+  }
 #endif
 }
 }  // namespace dequantization

 namespace detail {
transformer_engine/common/util/handle_manager.h  (+4 / -0)

@@ -9,7 +9,11 @@
 #include <vector>

+#ifdef __HIP_PLATFORM_AMD__
+#include "hip_runtime.h"
+#else
 #include "cuda_runtime.h"
+#endif
 #include "logging.h"

 namespace transformer_engine::detail {
transformer_engine/common/util/pybind_helper.h  (+4 / -0)

@@ -12,7 +12,11 @@
 #include <transformer_engine/fused_attn.h>
 #include <transformer_engine/transformer_engine.h>

+#ifdef __HIP_PLATFORM_AMD__
+#include "hip_runtime.h"
+#else
 #include "cuda_runtime.h"
+#endif

 #define NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) \
   pybind11::enum_<transformer_engine::DType>(m, "DType", pybind11::module_local()) \
transformer_engine/common/util/rtc.cpp  (+4 / -0)

@@ -11,7 +11,11 @@
 #include <utility>

 #include "../common.h"
+#ifdef USE_ROCM
+#include "../util/hip_driver.h"
+#else
 #include "../util/cuda_driver.h"
+#endif
 #include "../util/string.h"
 #include "../util/system.h"
transformer_engine/common/util/rtc.h  (+5 / -0)

@@ -19,8 +19,13 @@
 #include <vector>

 #include "../common.h"
+#ifdef __HIP_PLATFORM_AMD__
+#include "../util/hip_driver.h"
+#include "../util/hip_runtime.h"
+#else
 #include "../util/cuda_driver.h"
 #include "../util/cuda_runtime.h"
+#endif

 namespace transformer_engine {
transformer_engine/common/utils.cuh  (+8 / -1)

@@ -258,7 +258,7 @@ struct Converter<float2, hip_bfloat16x2> {
   static inline __device__ hip_bfloat16x2 convert(const float2& x) {
     union {
       hip_bfloat16x2 raw;
-      hip_bfloat16 elt[2];
+      __hip_bfloat16 elt[2];
     } tmp;
     tmp.elt[0] = __hip_bfloat16(x.x);
     tmp.elt[1] = __hip_bfloat16(x.y);

@@ -1020,6 +1020,13 @@ struct Quantized_Limits {
   static constexpr float emax_rcp = 1.0 / emax;
 };

+#ifdef __HIP_PLATFORM_AMD__
+#define __CUDA_ARCH_HAS_FEATURE__(FEATURE) \
+  ((__CUDA_ARCH__ >= 100 && FEATURE == SM100_ALL) || \
+   (__CUDA_ARCH__ >= 101 && FEATURE == SM101_ALL) || \
+   (__CUDA_ARCH__ >= 120 && FEATURE == SM120_ALL))
+#endif
+
 __device__ __forceinline__ e8m0_t float_to_e8m0(float val) {
   // TODO: nan/inf needs to be set for any value
   // of nan/inf in input not just amax.
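The __CUDA_ARCH_HAS_FEATURE__ shim above lets device code written against CUDA's family-specific architecture macros (the SM100_ALL-style feature queries) keep compiling on ROCm by reducing each feature test to a plain __CUDA_ARCH__ comparison. A hedged usage sketch; SM100_ALL and the remapped __CUDA_ARCH__ values are assumed to be provided by the surrounding ROCm port, not shown here:

__device__ void arch_conditioned_path() {
#if __CUDA_ARCH_HAS_FEATURE__(SM100_ALL)
  // Family-specific path (Blackwell-only instructions on the CUDA side).
#else
  // Portable fallback path.
#endif
}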
transformer_engine/pytorch/csrc/extensions.h  (+1 / -1)

@@ -108,7 +108,7 @@ std::optional<std::vector<at::Tensor>> te_general_batched_gemm(
     std::vector<int64_t> m_splits, std::vector<at::Tensor> bias,
     transformer_engine::DType bias_type, bool single_output,
     std::vector<at::Tensor> pre_gelu_out, bool grad, std::vector<at::Tensor> workspace,
     size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count);
 #endif

 /***************************************************************************************************
transformer_engine/pytorch/csrc/extensions/attention.cu  (+7 / -14)

@@ -18,8 +18,7 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend(
     size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
     int64_t window_size_right) {
 #ifdef __HIP_PLATFORM_AMD__
-  static_assert(false,
-                "Get_fused_attn_backend is not surpported in rocm for normalization yet.");
+  assert(false);
 #else
   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
       static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type,

@@ -101,8 +100,7 @@ std::vector<py::object> fused_attn_fwd(
     py::handle s_quantizer, py::handle o_quantizer, const c10::optional<at::Tensor> Bias,
     const c10::optional<at::Generator> rng_gen, size_t rng_elts_per_thread) {
 #ifdef __HIP_PLATFORM_AMD__
-  static_assert(false,
-                "Fused_attn_fwd is not surpported in rocm for normalization yet.");
+  assert(false);
 #else
   using namespace transformer_engine;
   using namespace transformer_engine::pytorch;

@@ -294,8 +292,7 @@ std::vector<py::object> fused_attn_bwd(
     const c10::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
     py::handle dp_quantizer, py::handle dqkv_quantizer) {
 #ifdef __HIP_PLATFORM_AMD__
-  static_assert(false,
-                "Fused_attn_bwd is not surpported in rocm for normalization yet.");
+  assert(false);
 #else
   using namespace transformer_engine;
   using namespace transformer_engine::pytorch;

@@ -1051,8 +1048,7 @@ at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_t
 template <typename scalar_t>
 void convert_thd_to_bshd_launcher(at::Tensor tensor, at::Tensor new_tensor, at::Tensor cu_seqlens,
                                   int b, int max_seq_len, int h, int d) {
   transformer_engine::fused_attn::
       convert_thd_to_bshd_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
           reinterpret_cast<scalar_t*>(tensor.data_ptr<scalar_t>()),
           reinterpret_cast<scalar_t*>(new_tensor.data_ptr<scalar_t>()), cu_seqlens.data_ptr<int>(),
           b, max_seq_len, h, d);

@@ -1091,8 +1087,7 @@ at::Tensor convert_thd_to_bshd(at::Tensor tensor, at::Tensor cu_seqlens, int b,
 template <typename scalar_t>
 void convert_bshd_to_thd_launcher(at::Tensor tensor, at::Tensor new_tensor, at::Tensor cu_seqlens,
                                   int b, int max_seq_len, int h, int d) {
   transformer_engine::fused_attn::
       convert_bshd_to_thd_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
           reinterpret_cast<scalar_t*>(tensor.data_ptr<scalar_t>()),
           reinterpret_cast<scalar_t*>(new_tensor.data_ptr<scalar_t>()), cu_seqlens.data_ptr<int>(),
           b, max_seq_len, h, d);

@@ -1152,15 +1147,13 @@ void copy_to_kv_cache_launcher(at::Tensor new_k, at::Tensor new_v, at::Tensor k_
   if (new_k.data_ptr() != nullptr && new_v.data_ptr() != nullptr &&
       k_cache.data_ptr() != nullptr && v_cache.data_ptr() != nullptr) {
     if (is_non_paged) {
       transformer_engine::fused_attn::
           reindex_kv_cache_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
               reinterpret_cast<scalar_t*>(k_cache.data_ptr<scalar_t>()),
               reinterpret_cast<scalar_t*>(v_cache.data_ptr<scalar_t>()),
               page_table.data_ptr<int>(), cu_new_lens.data_ptr<int>(),
               cu_cached_lens.data_ptr<int>(), h_kv, d_k, d_v, b, max_seq_len);
     }
     transformer_engine::fused_attn::
         copy_to_kv_cache_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
             reinterpret_cast<scalar_t*>(new_k.data_ptr<scalar_t>()),
             reinterpret_cast<scalar_t*>(new_v.data_ptr<scalar_t>()),
             reinterpret_cast<scalar_t*>(k_cache.data_ptr<scalar_t>()),
transformer_engine/pytorch/csrc/extensions/gemm.cpp  (+8 / -3)

@@ -12,7 +12,11 @@
 #include "../common.h"
 #include "common.h"
+#ifdef USE_ROCM
+#include "common/util/hip_runtime.h"
+#else
 #include "common/util/cuda_runtime.h"
+#endif
 #include "common/util/system.h"
 #include "extensions.h"
 #include "pybind.h"

@@ -531,9 +535,10 @@ std::optional<std::vector<at::Tensor>> te_general_batched_gemm(
     wrappers.emplace_back(std::move(wsp));
   }

   // For now, we only have multi-stream cublas backend.
-  nvte_multi_stream_cublas_batchgemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(),
-                                     te_pre_gelu_out.data(), te_A.size(), transa, transb, grad,
-                                     te_workspace.data(), accumulate, use_split_accumulator,
-                                     math_sm_count, at::cuda::getCurrentCUDAStream());
+  nvte_multi_stream_cublas_batchgemm(te_A_vector.data(), te_B_vector.data(), te_D_vector.data(),
+                                     te_bias_vector.data(), te_pre_gelu_out_vector.data(),
+                                     te_A_vector.size(), transa, transb, grad,
+                                     te_workspace_vector.data(), accumulate, use_split_accumulator,
+                                     math_sm_count, at::cuda::getCurrentCUDAStream());
   return bias;
 }
transformer_engine/pytorch/csrc/extensions/quantizer.cpp  (+2 / -2)

@@ -306,7 +306,7 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(
     rowwise_scale_inv = at::zeros({sinv0, sinv1}, opts);
     tensor.set_rowwise_data(data.data_ptr(), this->dtype, shape);
     tensor.set_rowwise_scale_inv(rowwise_scale_inv.data_ptr(), DType::kFloat8E8M0,
-                                 std::vector<size_t>{sinv0, sinv1});
+                                 std::vector<size_t>{static_cast<size_t>(sinv0), static_cast<size_t>(sinv1)});
   }
   if (columnwise_usage) {

@@ -317,7 +317,7 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(
     tensor.set_columnwise_data(columnwise_data.data_ptr(), this->dtype, shape);
     tensor.set_columnwise_scale_inv(columnwise_scale_inv.data_ptr(), DType::kFloat8E8M0,
-                                    std::vector<size_t>{sinv0, sinv1});
+                                    std::vector<size_t>{static_cast<size_t>(sinv0), static_cast<size_t>(sinv1)});
   }

   this->set_quantization_params(&tensor);
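The static_cast<size_t> additions fix a list-initialization narrowing error: sinv0 and sinv1 are signed values (likely int64_t, judging by the added casts), and a braced initializer list rejects implicit signed-to-unsigned narrowing, which clang-based hipcc reports as a hard error. A minimal reproduction, with an illustrative function name not from the file:

#include <cstdint>
#include <vector>

std::vector<size_t> make_shape(int64_t a, int64_t b) {
  // return {a, b};  // ill-formed: narrowing conversion from 'int64_t' to 'size_t'
  return {static_cast<size_t>(a), static_cast<size_t>(b)};
}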
transformer_engine/pytorch/module/base.py  (+1 / -1)

@@ -92,7 +92,7 @@ def get_multi_stream_cublas_batchgemm_workspace() -> List[torch.Tensor]:
     if not _multi_stream_cublas_batchgemm_workspace:
         for _ in range(tex._num_cublas_batchgemm_streams):
             _multi_stream_cublas_batchgemm_workspace.append(
-                torch.empty(get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda")
+                torch.empty(128, dtype=torch.uint8, device="cuda")
             )
     return _multi_stream_cublas_batchgemm_workspace