Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
30632f31
Unverified
Commit
30632f31
authored
Mar 17, 2023
by
Tim Moon
Committed by
GitHub
Mar 17, 2023
Browse files
Use 4B vector loads/stores in cast-transpose kernel for small matrices (#101)
Signed-off-by:
Tim Moon
<
tmoon@nvidia.com
>
parent
277b0be2
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
82 additions
and
51 deletions
+82
-51
transformer_engine/common/transpose/cast_transpose.cu
transformer_engine/common/transpose/cast_transpose.cu
+82
-51
No files found.
transformer_engine/common/transpose/cast_transpose.cu
View file @
30632f31
...
...
@@ -47,8 +47,6 @@ inline __device__ void cast_and_transpose_regs(const IVec (&in)[nvec_out],
// STUFF TO TUNE
constexpr
unsigned
int
n_warps_per_tile
=
4
;
constexpr
int
desired_load_size
=
8
;
constexpr
int
desired_store_size
=
8
;
constexpr
unsigned
int
max_threads_per_block
=
256
;
static_assert
(
n_warps_per_tile
*
THREADS_PER_WARP
<=
max_threads_per_block
);
...
...
@@ -321,61 +319,94 @@ void cast_transpose(const Tensor &input,
NVTE_CHECK
(
cast_output
->
scale
.
dptr
==
transposed_output
->
scale
.
dptr
,
"C and T outputs need to share scale tensor."
);
// Launch a specific cast-transpose kernel instantiation.
//
// Expands in the enclosing cast_transpose() scope and therefore reads
// `input`, `cast_output`, `transposed_output`, `stream`,
// `cast_transpose_num_threads`, `n_warps_per_tile`, `row_length`, and
// `num_rows` from the surrounding function.
//
// Steps:
//   1. Requests a 100% shared-memory carveout for the kernel via
//      cudaFuncSetAttribute (NOTE(review): the return status is ignored —
//      a failure here is silently dropped; consider checking it).
//   2. Launches `kernel` with dynamic shared memory sized for one
//      (THREADS_PER_WARP + 1)-column tile row per warp-group:
//      cast_transpose_num_threads / n_warps_per_tile *
//      (THREADS_PER_WARP + 1) * sizeof(Vec<OutputType, nvec_out>).
//      The +1 column presumably avoids shared-memory bank conflicts on the
//      transpose — TODO confirm against the kernel implementation.
//
// Wrapped in do { } while (false) so the macro behaves as a single
// statement at the call site.
#define LAUNCH_KERNEL(kernel, nvec_in, nvec_out, n_tiles, n_blocks, InputType, OutputType) \
  do { \
    cudaFuncSetAttribute(kernel<nvec_in, nvec_out, fp32, InputType, OutputType>, \
                         cudaFuncAttributePreferredSharedMemoryCarveout, \
                         100); \
    kernel<nvec_in, nvec_out, fp32, InputType, OutputType> \
      <<<n_blocks, \
         cast_transpose_num_threads, \
         cast_transpose_num_threads / n_warps_per_tile * \
         (THREADS_PER_WARP + 1) * sizeof(Vec<OutputType, nvec_out>), \
         stream>>>( \
        reinterpret_cast<const InputType *>(input.data.dptr), \
        reinterpret_cast<OutputType *>(cast_output->data.dptr), \
        reinterpret_cast<OutputType *>(transposed_output->data.dptr), \
        reinterpret_cast<const fp32 *>(cast_output->scale.dptr), \
        reinterpret_cast<fp32 *>(cast_output->amax.dptr), \
        row_length, num_rows, n_tiles); \
  } while (false)
// Launch the cast-transpose kernel for the given per-thread load/store
// widths (in bytes).
//
// Derives the vector lengths from the byte widths:
//   nvec_in  = load_size  / sizeof(InputType)
//   nvec_out = store_size / sizeof(OutputType)
// and validates that the matrix dimensions are divisible by them
// (NVTE_CHECK aborts with "Unsupported shape." otherwise).
//
// Tile and block counts come from the enclosing get_n_tiles/get_n_blocks
// helpers. When both dimensions are multiples of a full warp-tile
// (nvec * THREADS_PER_WARP) the aligned cast_transpose_kernel is used;
// otherwise the slower cast_transpose_kernel_notaligned handles the
// ragged edges.
//
// Expands in the enclosing cast_transpose() scope (uses row_length,
// num_rows, and the LAUNCH_KERNEL macro above); wrapped in
// do { } while (false) so it acts as a single statement.
#define LAUNCH_KERNEL_VEC_SIZES(load_size, store_size, InputType, OutputType) \
  do { \
    constexpr int nvec_in = load_size / sizeof(InputType); \
    constexpr int nvec_out = store_size / sizeof(OutputType); \
    \
    NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape."); \
    NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape."); \
    \
    const size_t n_tiles = get_n_tiles(load_size, store_size); \
    const size_t n_blocks = get_n_blocks(n_tiles); \
    \
    const bool full_tile = row_length % (nvec_in * THREADS_PER_WARP) == 0 && \
                           num_rows % (nvec_out * THREADS_PER_WARP) == 0; \
    \
    if (full_tile) { \
      LAUNCH_KERNEL(cast_transpose_kernel, \
                    nvec_in, nvec_out, n_tiles, n_blocks, \
                    InputType, OutputType); \
    } else { \
      LAUNCH_KERNEL(cast_transpose_kernel_notaligned, \
                    nvec_in, nvec_out, n_tiles, n_blocks, \
                    InputType, OutputType); \
    } \
  } while (false)
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT
(
input
.
data
.
dtype
,
InputType
,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT
(
cast_output
->
data
.
dtype
,
OutputType
,
constexpr
int
itype_size
=
sizeof
(
InputType
);
constexpr
int
otype_size
=
sizeof
(
OutputType
);
constexpr
int
nvec_in
=
desired_load_size
/
itype_size
;
constexpr
int
nvec_out
=
desired_store_size
/
otype_size
;
NVTE_CHECK
(
row_length
%
nvec_in
==
0
,
"Unsupported shape."
);
NVTE_CHECK
(
num_rows
%
nvec_out
==
0
,
"Unsupported shape."
);
const
size_t
n_tiles
=
DIVUP
(
row_length
,
static_cast
<
size_t
>
(
nvec_in
*
THREADS_PER_WARP
))
*
DIVUP
(
num_rows
,
static_cast
<
size_t
>
(
nvec_out
*
THREADS_PER_WARP
));
const
size_t
n_warps_per_block
=
cast_transpose_num_threads
/
THREADS_PER_WARP
;
const
size_t
n_blocks
=
DIVUP
(
n_tiles
*
n_warps_per_tile
,
n_warps_per_block
);
const
bool
full_tile
=
row_length
%
(
nvec_in
*
THREADS_PER_WARP
)
==
0
&&
num_rows
%
(
nvec_out
*
THREADS_PER_WARP
)
==
0
;
if
(
full_tile
)
{
cudaFuncSetAttribute
(
cast_transpose_kernel
<
nvec_in
,
nvec_out
,
fp32
,
InputType
,
OutputType
>
,
cudaFuncAttributePreferredSharedMemoryCarveout
,
100
);
cast_transpose_kernel
<
nvec_in
,
nvec_out
,
fp32
,
InputType
,
OutputType
>
<<<
n_blocks
,
cast_transpose_num_threads
,
cast_transpose_num_threads
/
n_warps_per_tile
*
(
THREADS_PER_WARP
+
1
)
*
sizeof
(
Vec
<
OutputType
,
nvec_out
>
),
stream
>>>
(
reinterpret_cast
<
const
InputType
*>
(
input
.
data
.
dptr
),
reinterpret_cast
<
OutputType
*>
(
cast_output
->
data
.
dptr
),
reinterpret_cast
<
OutputType
*>
(
transposed_output
->
data
.
dptr
),
reinterpret_cast
<
const
fp32
*>
(
cast_output
->
scale
.
dptr
),
reinterpret_cast
<
fp32
*>
(
cast_output
->
amax
.
dptr
),
row_length
,
num_rows
,
n_tiles
);
// Estimate number of SMs
// Note: H100 has 132 SMs, A100 has 108 SMs.
// Note: Directly querying number of SMs with cudaGetDeviceProperties is
// slow (>1 ms). Consider querying once and caching.
const
int
n_sms
=
128
;
// Helper functions to get kernel configuration
auto
get_n_tiles
=
[
=
]
(
size_t
load_size
,
size_t
store_size
)
->
int
{
constexpr
size_t
threads_per_warp
=
static_cast
<
size_t
>
(
THREADS_PER_WARP
);
size_t
nvec_in
=
load_size
/
sizeof
(
InputType
);
size_t
nvec_out
=
store_size
/
sizeof
(
OutputType
);
size_t
n_tiles
=
DIVUP
(
row_length
,
nvec_in
*
threads_per_warp
)
*
DIVUP
(
num_rows
,
nvec_out
*
threads_per_warp
);
return
n_tiles
;
};
auto
get_n_blocks
=
[
=
]
(
size_t
n_tiles
)
->
int
{
size_t
n_warps_per_block
=
cast_transpose_num_threads
/
THREADS_PER_WARP
;
size_t
n_blocks
=
DIVUP
(
n_tiles
*
n_warps_per_tile
,
n_warps_per_block
);
return
n_blocks
;
};
// Estimate optimal vector sizes and run
// Note: Consider reducing to 2B or 1B loads/stores for
// sufficiently small matrices. Need to consider whether reduced
// cache efficiency is worth increased SM utilization. Also need
// to keep in mind whether datatype can fit.
const
size_t
estimated_n_tiles
=
get_n_tiles
(
8
,
8
);
const
size_t
estimated_n_blocks
=
get_n_blocks
(
estimated_n_tiles
);
if
(
estimated_n_blocks
>=
n_sms
)
{
LAUNCH_KERNEL_VEC_SIZES
(
8
,
8
,
InputType
,
OutputType
);
}
else
{
cudaFuncSetAttribute
(
cast_transpose_kernel_notaligned
<
nvec_in
,
nvec_out
,
fp32
,
InputType
,
OutputType
>
,
cudaFuncAttributePreferredSharedMemoryCarveout
,
100
);
cast_transpose_kernel_notaligned
<
nvec_in
,
nvec_out
,
fp32
,
InputType
,
OutputType
>
<<<
n_blocks
,
cast_transpose_num_threads
,
cast_transpose_num_threads
/
n_warps_per_tile
*
(
THREADS_PER_WARP
+
1
)
*
sizeof
(
Vec
<
OutputType
,
nvec_out
>
),
stream
>>>
(
reinterpret_cast
<
const
InputType
*>
(
input
.
data
.
dptr
),
reinterpret_cast
<
OutputType
*>
(
cast_output
->
data
.
dptr
),
reinterpret_cast
<
OutputType
*>
(
transposed_output
->
data
.
dptr
),
reinterpret_cast
<
const
fp32
*>
(
cast_output
->
scale
.
dptr
),
reinterpret_cast
<
fp32
*>
(
cast_output
->
amax
.
dptr
),
row_length
,
num_rows
,
n_tiles
);
LAUNCH_KERNEL_VEC_SIZES
(
4
,
4
,
InputType
,
OutputType
);
}
);
// NOLINT(*)
);
// NOLINT(*)
#undef LAUNCH_KERNEL
#undef LAUNCH_KERNEL_VEC_SIZES
}
}
// namespace transformer_engine
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment