Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
331f2fc4
Commit
331f2fc4
authored
Jun 18, 2025
by
wenjh
Browse files
Resolve merge issue from nv of vector blockwise
Signed-off-by:
wenjh
<
wenjh@sugon.com
>
parent
1f9c104b
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
125 additions
and
6 deletions
+125
-6
transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu
...e/common/transpose/quantize_transpose_vector_blockwise.cu
+125
-6
No files found.
transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu
View file @
331f2fc4
...
...
@@ -504,7 +504,11 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
const
bool
is_src_lane
=
thr_idx_in_warp
==
0
;
amax
=
warp_reduce_max
<
kThreadsPerWarp
>
(
amax
);
constexpr
int
lane_zero
=
0
;
#ifdef __HIP_PLATFORM_AMD__
amax
=
__shfl_sync
(
0xFFFFFFFF
,
amax
,
lane_zero
,
kThreadsPerWarp
);
#else
amax
=
__shfl_sync
(
0xFFFFFFFF
,
amax
,
lane_zero
);
#endif
// Step 3.4: Compute scale
CType
scale
;
scale
=
compute_scale_from_types
<
IType
,
OType
>
(
amax
,
epsilon
,
pow_2_scaling
);
...
...
@@ -528,9 +532,15 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
OVec
output_vec
;
#pragma unroll
for
(
int
i
=
0
;
i
<
kThreadTileCol
;
++
i
)
{
if
constexpr
(
std
::
is_same_v
<
OType
,
int8_t
>
)
{
output_vec
.
data
.
elt
[
i
]
=
static_cast
<
OType
>
(
lroundf
(
fmaxf
(
-
127.0
f
,
fminf
(
127.0
f
,
static_cast
<
CType
>
(
reg_vec
[
row_idx
].
data
.
elt
[
i
])
*
thr_scale
.
data
.
elt
[
i
]))));
}
else
{
output_vec
.
data
.
elt
[
i
]
=
static_cast
<
OType
>
(
static_cast
<
CType
>
(
reg_vec
[
row_idx
].
data
.
elt
[
i
])
*
thr_scale
.
data
.
elt
[
i
]);
}
}
// Step 3.7: Store output_t
if
constexpr
(
kAligned
)
{
output_vec
.
store_to
(
output_g
);
...
...
@@ -558,9 +568,11 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
const
size_t
scale_t_stride_x
,
const
size_t
scale_t_stride_y
,
const
float
epsilon
,
FP8BlockwiseRowwiseOption
rowwise_option
,
FP8BlockwiseColumnwiseOption
columnwise_option
,
const
bool
pow_2_scaling
)
{
bool
return_rowwise
=
rowwise_option
==
FP8BlockwiseRowwiseOption
::
ROWWISE
;
bool
return_columnwise_transpose
=
columnwise_option
==
FP8BlockwiseColumnwiseOption
::
COLUMNWISE_TRANSPOSE
;
bool
return_rowwise
=
rowwise_option
!=
FP8BlockwiseRowwiseOption
::
NONE
;
bool
return_columnwise_gemm_ready
=
columnwise_option
==
FP8BlockwiseColumnwiseOption
::
COLUMNWISE_GEMM_READY
;
bool
return_columnwise_compact
=
columnwise_option
==
FP8BlockwiseColumnwiseOption
::
COLUMNWISE_COMPACT
;
using
SMemVec
=
Vec
<
IType
,
kNVecSMem
>
;
using
OVec
=
Vec
<
OType
,
kNVecOut
>
;
...
...
@@ -724,7 +736,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
}
// Step 3: Transpose, cast and store to output_t
if
(
return_columnwise_
transpose
)
{
if
(
return_columnwise_
gemm_ready
)
{
constexpr
int
c_stride
=
kThreadsPerBlock
/
kNumThreadsStore
;
// Stride in columns of shared memory
constexpr
int
num_iterations
=
kTileDim
/
(
c_stride
*
kNVecSMem
);
...
...
@@ -825,6 +837,113 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
}
}
}
// Step 4 (return_columnwise_compact): cast in 128x1 style and store to output, skip transpose
if
(
return_columnwise_compact
)
{
// thread tile should be 4x16, 16 means 8 smem reads
constexpr
int
kThreadTileRow
=
kTileDim
/
kThreadsPerWarp
;
constexpr
int
kThreadTileCol
=
kNVecOut
;
using
RegVec
=
Vec
<
IType
,
kThreadTileCol
>
;
using
RegScaleVec
=
Vec
<
CType
,
kThreadTileCol
>
;
constexpr
int
num_smem_reads
=
kNVecOut
/
kNVecSMem
;
// c_stride will not be used here because we only have one iteration
// constexpr int c_stride = kThreadTileCol * kNumWarps / kNVecSMem;
constexpr
int
num_iterations
=
kTileDim
/
(
kNumWarps
*
kThreadTileCol
);
// should be only one iteration
static_assert
(
num_iterations
==
1
,
"num_iterations should be 1 for columnwise non-transpose case"
);
const
int
thr_idx_in_warp
=
threadIdx
.
x
%
kThreadsPerWarp
;
const
int
warp_idx
=
threadIdx
.
x
/
kThreadsPerWarp
;
const
int
r_s
=
thr_idx_in_warp
*
kThreadTileRow
;
// Row in shared memory
int
c_s
=
warp_idx
*
num_smem_reads
;
// Column in shared memory
size_t
r_g
=
static_cast
<
size_t
>
(
blockIdx
.
y
)
*
kTileDim
+
r_s
;
// Row in global memory
const
size_t
c_g
=
static_cast
<
size_t
>
(
blockIdx
.
x
)
*
kTileDim
+
c_s
*
kNVecSMem
;
// Column in global memory
const
size_t
num_ele
=
c_g
<
row_length
?
min
(
static_cast
<
size_t
>
(
kThreadTileCol
),
row_length
-
c_g
)
:
0
;
// For not aligned case
#pragma unroll
for
(
int
iter
=
0
;
iter
<
num_iterations
;
++
iter
)
{
RegVec
reg_vec
[
kThreadTileRow
];
RegScaleVec
thr_scale
;
// Step 3.1: Load from shared memory to registers
#pragma unroll
for
(
int
i
=
0
;
i
<
kThreadTileRow
;
++
i
)
{
int
r
=
r_s
+
i
;
#pragma unroll
for
(
int
j
=
0
;
j
<
num_smem_reads
;
++
j
)
{
int
c
=
c_s
+
j
;
SMemVec
smem_vec
=
smem
[
c
*
kTileDim
+
r
];
// copy smem_vec to reg vec with its elements
#pragma unroll
for
(
int
k
=
0
;
k
<
kNVecSMem
;
++
k
)
{
reg_vec
[
i
].
data
.
elt
[
j
*
kNVecSMem
+
k
]
=
smem_vec
.
data
.
elt
[
k
];
}
}
}
#pragma unroll
for
(
int
reg_idx
=
0
;
reg_idx
<
kThreadTileCol
;
++
reg_idx
)
{
// Step 3.2: Compute local amax
CType
amax
=
0
;
#pragma unroll
for
(
int
i
=
0
;
i
<
kThreadTileRow
;
++
i
)
{
amax
=
fmaxf
(
amax
,
fabsf
(
reg_vec
[
i
].
data
.
elt
[
reg_idx
]));
}
// Step 3.3: Reduce amax
const
bool
is_src_lane
=
thr_idx_in_warp
==
0
;
amax
=
warp_reduce_max
<
kThreadsPerWarp
>
(
amax
);
constexpr
int
lane_zero
=
0
;
#ifdef __HIP_PLATFORM_AMD__
amax
=
__shfl_sync
(
0xFFFFFFFF
,
amax
,
lane_zero
,
kThreadsPerWarp
);
#else
amax
=
__shfl_sync
(
0xFFFFFFFF
,
amax
,
lane_zero
);
#endif
// Step 3.4: Compute scale
CType
scale
;
scale
=
compute_scale_from_types
<
IType
,
OType
>
(
amax
,
epsilon
,
pow_2_scaling
);
thr_scale
.
data
.
elt
[
reg_idx
]
=
scale
;
// Step 3.5: Write scale_inv_t
bool
write_scale_inv
=
is_src_lane
;
if
constexpr
(
!
kAligned
)
{
write_scale_inv
&=
(
c_g
+
reg_idx
<
row_length
);
}
if
(
write_scale_inv
)
{
CType
scale_inv
=
1.0
/
scale
;
size_t
row_idx
=
static_cast
<
size_t
>
(
blockIdx
.
y
);
size_t
col_idx
=
static_cast
<
size_t
>
(
blockIdx
.
x
)
*
kTileDim
+
c_s
*
kNVecSMem
+
reg_idx
;
tile_scales_inv_t
[
row_idx
*
scale_t_stride_y
+
col_idx
*
scale_t_stride_x
]
=
scale_inv
;
}
}
// Step 3.6: Quantize
for
(
int
row_idx
=
0
;
row_idx
<
kThreadTileRow
;
++
row_idx
)
{
OType
*
output_g
=
&
output_t
[(
r_g
+
row_idx
)
*
row_length
+
c_g
];
// Output address in global memory
OVec
output_vec
;
#pragma unroll
for
(
int
i
=
0
;
i
<
kThreadTileCol
;
++
i
)
{
if
constexpr
(
std
::
is_same_v
<
OType
,
int8_t
>
)
{
output_vec
.
data
.
elt
[
i
]
=
static_cast
<
OType
>
(
lroundf
(
fmaxf
(
-
127.0
f
,
fminf
(
127.0
f
,
static_cast
<
CType
>
(
reg_vec
[
row_idx
].
data
.
elt
[
i
])
*
thr_scale
.
data
.
elt
[
i
]))));
}
else
{
output_vec
.
data
.
elt
[
i
]
=
static_cast
<
OType
>
(
static_cast
<
CType
>
(
reg_vec
[
row_idx
].
data
.
elt
[
i
])
*
thr_scale
.
data
.
elt
[
i
]);
}
}
// Step 3.7: Store output_t
if
constexpr
(
kAligned
)
{
output_vec
.
store_to
(
output_g
);
}
else
{
if
(
r_g
+
row_idx
<
num_rows
)
{
output_vec
.
store_to_elts
(
output_g
,
0
,
num_ele
);
}
}
}
// Step 3.8: Update output address, column index of shared memory
// this section shouldn't matter since we only have one iteration
}
}
}
#endif
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment