OpenDAS / TransformerEngine · Commits

Commit 7aeb5a72
Authored Jan 23, 2026 by wenjh
Merge branch 'develop_v2.10' into release_v2.10
Parents: d786eedd, 261e476b

Showing 1 changed file with 81 additions and 167 deletions.

transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu (+81, -167)
@@ -989,15 +989,14 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
 #ifdef __HIP_PLATFORM_AMD__
 constexpr size_t kThreadsPerWarp_blocklen_128 = 64;
-constexpr int kTileDim64_Rowwise = 64;
-constexpr int kNVecSMem_Rowwise = 4;  // The number of elements each LDS/STS touches
-constexpr int kThreadsPerBlock_Rowwise = 512;  // Thread block size, 8 warps in total
-constexpr int kSMemRow_Rowwise = kTileDim64_Rowwise;
-constexpr int kSMemCol_Rowwise = (kTileDim / kNVecSMem_Rowwise);
-constexpr int kSMemSize_Rowwise = kSMemRow_Rowwise * kSMemCol_Rowwise * kNVecSMem_Rowwise;
+// Optimized Rowwise kernel: Direct register processing without shared memory
+// Each warp (64 threads) processes multiple rows, 8 threads collaborate on one 128-element row
+constexpr int kThreadsPerBlock_Rowwise_Opt = 512;
+constexpr int kThreadsPerRow_Rowwise = 8;  // 8 threads per row, each handles 16 elements = 128 total
+constexpr int kRowsPerBlock_Rowwise = kThreadsPerBlock_Rowwise_Opt / kThreadsPerRow_Rowwise;  // 64 rows per block
 template <bool kAligned, typename CType, typename IType, typename OType>
-__global__ void __launch_bounds__(kThreadsPerBlock_Rowwise)
+__global__ void __launch_bounds__(kThreadsPerBlock_Rowwise_Opt, 4)
 block_scaled_1d_cast_transpose_kernel_rowwise(const IType* const input, OType* const output_c,
                                               CType* const tile_scales_inv_c,
                                               const size_t row_length, const size_t num_rows,
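Note on the new constants: the optimized kernel replaces the shared-memory tile bookkeeping (kSMemRow/kSMemCol/kSMemSize) with a fixed thread-to-data mapping, and the new __launch_bounds__(kThreadsPerBlock_Rowwise_Opt, 4) additionally asks for at least 4 resident blocks per compute unit. As a sanity check, the geometry fits together as below. This is a standalone sketch: kTileDim = 128 and kNVecOut = 16 are assumptions inferred from the comments ("8 threads per row, each handles 16 elements = 128 total"), not definitions visible in this hunk.

#include <cstdio>

// Assumed values (see note above); the real definitions live elsewhere in
// quantize_transpose_vector_blockwise.cu.
constexpr int kTileDim = 128;  // quantization block length on the block_len == 128 path
constexpr int kNVecOut = 16;   // elements produced per thread
constexpr int kThreadsPerBlock_Rowwise_Opt = 512;
constexpr int kThreadsPerRow_Rowwise = 8;
constexpr int kRowsPerBlock_Rowwise = kThreadsPerBlock_Rowwise_Opt / kThreadsPerRow_Rowwise;

static_assert(kThreadsPerRow_Rowwise * kNVecOut == kTileDim,
              "8 threads x 16 elements cover one 128-element row block");
static_assert(kRowsPerBlock_Rowwise == 64, "each 512-thread block covers 64 rows");

int main() {
  printf("rows per block = %d\n", kRowsPerBlock_Rowwise);  // prints 64
  return 0;
}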
@@ -1008,168 +1007,90 @@ __global__ void __launch_bounds__(kThreadsPerBlock_Rowwise)
   if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) {
     return;
   }
-  bool return_rowwise = rowwise_option != FP8BlockwiseRowwiseOption::NONE;
-  using SMemVec = Vec<IType, kNVecSMem_Rowwise>;
+  if (rowwise_option == FP8BlockwiseRowwiseOption::NONE) return;
+  using IVec = Vec<IType, kNVecOut>;  // 16 elements per thread
   using OVec = Vec<OType, kNVecOut>;
-  union IVec {
-    Vec<IType, kNVecIn> input_type;
-    Vec<SMemVec, kNVecIn / kNVecSMem_Rowwise> smem_type;
-  };
-  extern __shared__ char smem_base[];
-  SMemVec* smem = reinterpret_cast<SMemVec*>(smem_base);
-  // Step 1: Load input to shared memory
-  {
-    constexpr int r_stride = kThreadsPerBlock_Rowwise / kNumThreadsLoad;  // stride in rows of shared memory
-    constexpr int num_iterations = kTileDim64_Rowwise / r_stride;  // 64/16=4
-    const int c_s = (threadIdx.x % kNumThreadsLoad) * (kNVecIn / kNVecSMem_Rowwise);  // Column in shared memory
-    int r_s = threadIdx.x / kNumThreadsLoad;  // Row in shared memory
-    const size_t c_g = static_cast<size_t>(blockIdx.x) * kTileDim + c_s * kNVecSMem_Rowwise;  // Column in global memory
-    size_t r_g = static_cast<size_t>(blockIdx.y) * kTileDim64_Rowwise + r_s;  // Row in global memory
-    const size_t stride_g = static_cast<size_t>(r_stride) * row_length;  // Stride in global memory
-    const size_t num_ele = c_g < row_length ? std::min(static_cast<size_t>(kNVecIn), row_length - c_g) : 0;  // For not aligned case
-    const IType* input_g = &input[r_g * row_length + c_g];  // Input address in global memory
-#pragma unroll
-    for (int iter = 0; iter < num_iterations; ++iter) {
-      IVec input_vec;
-      // Step 1.1: Load from global memory (input) to registers
-      if constexpr (kAligned) {
-        input_vec.input_type.load_from(input_g);
-      } else {
-        if (r_g < num_rows) {
-          input_vec.input_type.load_from_elts(input_g, 0, num_ele);
-        } else {
-          input_vec.input_type.clear();
-        }
-      }
-      // Step 1.2: Write to shared memory - row Major
-#pragma unroll
-      for (int i = 0; i < kNVecIn / kNVecSMem_Rowwise; ++i) {
-        int c = c_s + i;
-        int r = r_s;
-        // row Major Store
-        smem[r * kSMemCol_Rowwise + c] = input_vec.smem_type.data.elt[i];
-      }
-      // Step 1.3: Update input address, row index of shared memory, (and row index of global memory for not aligned case)
-      input_g += stride_g;
-      r_s += r_stride;
-      if constexpr (!kAligned) {
-        r_g += r_stride;
-      }
-    }
-  }
-  __syncthreads();
-  // Step 2: Cast and store to output_c
-  if (return_rowwise) {
-    constexpr int r_stride = kThreadsPerBlock_Rowwise / kNumThreadsStore;  // stride in rows of shared memory
-    constexpr int num_iterations = kTileDim64_Rowwise / r_stride;
-    const int c_s = (threadIdx.x % kNumThreadsStore) * (kNVecOut / kNVecSMem_Rowwise);  // Column in shared memory
-    int r_s = threadIdx.x / kNumThreadsStore;  // Row in shared memory
-    const size_t c_g = static_cast<size_t>(blockIdx.x) * kTileDim + c_s * kNVecSMem_Rowwise;  // Column in global memory
-    size_t r_g = static_cast<size_t>(blockIdx.y) * kTileDim64_Rowwise + r_s;  // Row in global memory
-    const size_t stride_g = static_cast<size_t>(r_stride) * row_length;  // Stride in global memory
-    const size_t num_ele = c_g < row_length ? std::min(static_cast<size_t>(kNVecOut), row_length - c_g) : 0;  // For not aligned case
-    OType* output_g = &output_c[r_g * row_length + c_g];  // Output address in global memory
-    // Each kNumThreadsStore threads form a warp process one row, we need to find the lane id of
-    // the first thread to do the reduction.
-    const unsigned src_lane = (threadIdx.x % kThreadsPerWarp_blocklen_128) / kNumThreadsStore * kNumThreadsStore;
-    // This mask represents which threads should do the reduction together.
-    const unsigned mask = ((1 << kNumThreadsStore) - 1) << src_lane;
-    const bool is_src_lane = (threadIdx.x % kNumThreadsStore) == 0;
-#pragma unroll
-    for (int iter = 0; iter < num_iterations; ++iter) {
-      SMemVec smem_vec[kNVecOut / kNVecSMem_Rowwise];
-      // Step 2.1: Load from shared memory to registers - Column Major
-#pragma unroll
-      for (int i = 0; i < kNVecOut / kNVecSMem_Rowwise; ++i) {
-        int c = c_s + i;
-        int r = r_s;
-        // Column Major Read
-        smem_vec[i] = smem[r * kSMemCol_Rowwise + c];
-      }
-      // Step 2.2: Compute local amax
-      CType amax = 0;
-#pragma unroll
-      for (int i = 0; i < kNVecOut / kNVecSMem_Rowwise; ++i) {
-#pragma unroll
-        for (int j = 0; j < kNVecSMem_Rowwise; ++j) {
-          __builtin_assume(amax >= 0);
-          amax = fmaxf(amax, fabsf(smem_vec[i].data.elt[j]));
-        }
-      }
-      // Step 2.3: Reduce amax
-#pragma unroll
-      for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) {
-        //const float other_amax =__shfl_xor_sync((unsigned long long)(mask), amax, delta, kThreadsPerWarp);
-        const float other_amax = __shfl_xor(amax, delta, kThreadsPerWarp_blocklen_128);
-        __builtin_assume(amax >= 0);
-        __builtin_assume(other_amax >= 0);
-        amax = fmaxf(amax, other_amax);
-      }
-      // Step 2.4: Compute scale
-      CType scale;
-      scale = compute_scale_from_types<IType, OType>(amax, epsilon, pow_2_scaling);
-      // Step 2.5: Write scale_inv
-      bool write_scale_inv = is_src_lane;
-      if constexpr (!kAligned) {
-        write_scale_inv &= (r_g < num_rows);
-      }
-      if (write_scale_inv) {
-        CType scale_inv = 1.0 / scale;
-        size_t row_idx = static_cast<size_t>(blockIdx.y) * kTileDim64_Rowwise + r_s;
-        size_t col_idx = static_cast<size_t>(blockIdx.x);
-        tile_scales_inv_c[row_idx * scale_stride_y + col_idx * scale_stride_x] = scale_inv;
-      }
-      // Step 2.6: Quantize
-      OVec output_vec;
-#pragma unroll
-      for (int i = 0; i < kNVecOut / kNVecSMem_Rowwise; ++i) {
-#pragma unroll
-        for (int j = 0; j < kNVecSMem_Rowwise; ++j) {
-          if constexpr (std::is_same_v<OType, int8_t>) {
-            output_vec.data.elt[i * kNVecSMem_Rowwise + j] = static_cast<OType>(lroundf(fmaxf(
-                -127.0f, fminf(127.0f, static_cast<CType>(smem_vec[i].data.elt[j]) * scale))));
-          } else {
-            output_vec.data.elt[i * kNVecSMem_Rowwise + j] =
-                static_cast<OType>(static_cast<CType>(smem_vec[i].data.elt[j]) * scale);
-          }
-        }
-      }
-      // Step 2.7: Store output_c
-      if constexpr (kAligned) {
-        output_vec.store_to(output_g);
-      } else {
-        if (r_g < num_rows) {
-          output_vec.store_to_elts(output_g, 0, num_ele);
-        }
-      }
-      // Step 2.8: Update output address, row index of shared memory (and row index of global memory for not aligned case)
-      output_g += stride_g;
-      r_s += r_stride;
-      if constexpr (!kAligned) {
-        r_g += r_stride;
-      }
-    }
-  }
+  // Thread indexing: 8 threads per row, 64 rows per block
+  const int thr_in_row = threadIdx.x % kThreadsPerRow_Rowwise;    // 0-7: position within row
+  const int row_in_block = threadIdx.x / kThreadsPerRow_Rowwise;  // 0-63: which row in block
+  // Global position
+  const size_t r_g = static_cast<size_t>(blockIdx.y) * kRowsPerBlock_Rowwise + row_in_block;
+  const size_t c_g = static_cast<size_t>(blockIdx.x) * kTileDim + thr_in_row * kNVecOut;
+  // Early exit if out of bounds
+  if constexpr (!kAligned) {
+    if (r_g >= num_rows) return;
+  }
+  // Calculate number of elements for non-aligned case
+  const size_t num_ele = c_g < row_length
+                             ? (c_g + kNVecOut <= row_length ? kNVecOut : row_length - c_g)
+                             : 0;
+  // Step 1: Load directly from global memory to registers (NO shared memory!)
+  IVec input_vec;
+  const IType* input_g = &input[r_g * row_length + c_g];
+  if constexpr (kAligned) {
+    input_vec.load_from(input_g);
+  } else {
+    if (num_ele > 0) {
+      input_vec.load_from_elts(input_g, 0, num_ele);
+    } else {
+      input_vec.clear();
+    }
+  }
+  // Step 2: Compute local amax (16 elements per thread)
+  CType amax = 0;
+#pragma unroll
+  for (int i = 0; i < kNVecOut; ++i) {
+    __builtin_assume(amax >= 0);
+    amax = fmaxf(amax, fabsf(static_cast<CType>(input_vec.data.elt[i])));
+  }
+  // Step 3: Reduce amax across 8 threads (128 elements total)
+#pragma unroll
+  for (int delta = kThreadsPerRow_Rowwise / 2; delta > 0; delta /= 2) {
+    const float other_amax = __shfl_xor(amax, delta, kThreadsPerRow_Rowwise);
+    __builtin_assume(amax >= 0);
+    __builtin_assume(other_amax >= 0);
+    amax = fmaxf(amax, other_amax);
+  }
+  // Step 4: Compute scale
+  CType scale = compute_scale_from_types<IType, OType>(amax, epsilon, pow_2_scaling);
+  // Step 5: Write scale_inv (only first thread in each row)
+  if (thr_in_row == 0) {
+    CType scale_inv = 1.0f / scale;
+    size_t row_idx = r_g;
+    size_t col_idx = static_cast<size_t>(blockIdx.x);
+    tile_scales_inv_c[row_idx * scale_stride_y + col_idx * scale_stride_x] = scale_inv;
+  }
+  // Step 6: Quantize directly in registers
+  OVec output_vec;
+#pragma unroll
+  for (int i = 0; i < kNVecOut; ++i) {
+    if constexpr (std::is_same_v<OType, int8_t>) {
+      output_vec.data.elt[i] = static_cast<OType>(lroundf(
+          fmaxf(-127.0f, fminf(127.0f, static_cast<CType>(input_vec.data.elt[i]) * scale))));
+    } else {
+      output_vec.data.elt[i] = static_cast<OType>(static_cast<CType>(input_vec.data.elt[i]) * scale);
+    }
+  }
+  // Step 7: Store directly to global memory
+  OType* output_g = &output_c[r_g * row_length + c_g];
+  if constexpr (kAligned) {
+    output_vec.store_to(output_g);
+  } else {
+    if (num_ele > 0) {
+      output_vec.store_to_elts(output_g, 0, num_ele);
+    }
+  }
 }
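Note on the rewrite above: the old kernel staged a 128x64 tile through LDS (Step 1 row-major write, Step 2.1 column-major read) so that groups of store threads could cooperate on each row; the new kernel gives each thread its 16 elements directly in registers and only communicates through the Step 3 shuffle. Below is a minimal HIP sketch of that butterfly amax reduction in isolation, assuming the __shfl_xor(var, lane_mask, width) overload the diff itself uses; the kernel name, buffers, and launch shape here are illustrative, not part of the commit.

#include <hip/hip_runtime.h>

// Each group of 8 consecutive lanes reduces its per-thread amax with an
// XOR butterfly; after log2(8) = 3 steps every lane in the group holds
// the group's maximum, with no shared memory involved.
__global__ void amax_reduce_demo(const float* in, float* out) {
  constexpr int kThreadsPerRow = 8;
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  float amax = fabsf(in[tid]);
  for (int delta = kThreadsPerRow / 2; delta > 0; delta /= 2) {
    const float other = __shfl_xor(amax, delta, kThreadsPerRow);
    amax = fmaxf(amax, other);
  }
  // Mirror of Step 5: only the first lane of each 8-lane group writes.
  if (threadIdx.x % kThreadsPerRow == 0) {
    out[tid / kThreadsPerRow] = amax;
  }
}

Because every lane leaves the butterfly with the same amax, Step 4 can compute the scale redundantly in all 8 lanes, which is cheaper than broadcasting it from a single lane.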
@@ -1177,7 +1098,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock_Rowwise)
 constexpr int kTileDim64_Colwise = 64;
 constexpr int kNVecSMem_Colwise = 2;
 constexpr int kSMemRow_Colwise = kTileDim;
-constexpr int kSMemCol_Colwise = (kTileDim64_Colwise / kNVecSMem_Colwise);
+constexpr int kSMemCol_Colwise = (kTileDim64_Colwise / kNVecSMem_Colwise) + 1;  // Padding to avoid bank conflict
 constexpr int kSMemSize_Colwise = kSMemRow_Colwise * kSMemCol_Colwise * kNVecSMem_Colwise;
 constexpr int kNumThreadsLoad_Colwise = kTileDim64_Colwise / kNVecIn;
 constexpr int kNumThreadsStore_Colwise = kTileDim / kNVecOut;
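Note on the "+ 1": the colwise kernel keeps its shared-memory staging, and the pad widens each logical row of the tile by one element so that column-major accesses stop landing in the same LDS bank. A hypothetical stand-alone illustration follows, assuming the common 32-bank, 4-byte-word LDS layout; the names and sizes here are not from the commit.

#include <hip/hip_runtime.h>

__global__ void padded_smem_demo(float* out) {
  constexpr int kRows = 64;
  constexpr int kCols = 32;
  // Without the pad, smem[r][0] for consecutive r walks addresses with a
  // stride of exactly 32 words, so (on a 32-bank LDS) every lane hits the
  // same bank and the accesses serialize. The +1 shifts each row into the
  // next bank, making the column walk conflict-free.
  __shared__ float smem[kRows][kCols + 1];
  const int r = threadIdx.x;
  if (r < kRows) smem[r][0] = static_cast<float>(r);
  __syncthreads();
  if (r < kRows) out[r] = smem[r][0];  // conflict-free column walk
}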
@@ -1583,19 +1504,12 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
   while (true) {
     if (128 == block_len) {
       if (rowwise_option != FP8BlockwiseRowwiseOption::NONE) {
-        size_t smem_bytes = kSMemSize_Rowwise * sizeof(InputType);
         const size_t num_blocks_x = DIVUP(row_length, (size_t)(block_len));
-        const size_t num_blocks_y = DIVUP(num_rows, (size_t)(block_len / 2));
+        const size_t num_blocks_y = DIVUP(num_rows, (size_t)kRowsPerBlock_Rowwise);
         dim3 grid(num_blocks_x, num_blocks_y, 1);
-        if (smem_bytes >= 48 * 1024) {
-          cudaError_t err = cudaFuncSetAttribute(
-              (const void*)&block_scaled_1d_cast_transpose_kernel_rowwise<kAligned, float,
-                                                                          InputType, OutputType>,
-              cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
-          NVTE_CHECK(err == cudaSuccess, "Failed to set dynamic shared memory size.");
-        }
         block_scaled_1d_cast_transpose_kernel_rowwise<kAligned, float, InputType, OutputType>
-            <<<grid, kThreadsPerBlock_Rowwise, smem_bytes, stream>>>(
+            <<<grid, kThreadsPerBlock_Rowwise_Opt, 0, stream>>>(
                 reinterpret_cast<const InputType*>(input.dptr),
                 reinterpret_cast<OutputType*>(output.dptr),
                 reinterpret_cast<float*>(scale_inv.dptr), row_length, num_rows,
...
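Note on the launch-site change: because the optimized kernel uses no shared memory, the smem_bytes sizing and the cudaFuncSetAttribute / NVTE_CHECK step are gone and the launch passes 0 bytes of dynamic LDS. The new y-dimension DIVUP(num_rows, kRowsPerBlock_Rowwise) is numerically the same as the old DIVUP(num_rows, block_len / 2) on this block_len == 128 path (both divide by 64), but now names the kernel's actual rows-per-block. A sketch of the arithmetic, assuming DIVUP is the usual ceiling division and picking example sizes:

#include <cstdio>

constexpr size_t DIVUP(size_t a, size_t b) { return (a + b - 1) / b; }

int main() {
  const size_t row_length = 4096, num_rows = 1000;  // example tensor shape
  const size_t block_len = 128;                     // quantization block length
  const size_t kRowsPerBlock_Rowwise = 64;          // from the constants above
  const size_t num_blocks_x = DIVUP(row_length, block_len);            // 32
  const size_t num_blocks_y = DIVUP(num_rows, kRowsPerBlock_Rowwise);  // 16, covers 1024 rows
  printf("grid = (%zu, %zu, 1)\n", num_blocks_x, num_blocks_y);
  return 0;
}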