OpenDAS / TransformerEngine · Commits

Commit fca88163, authored Sep 11, 2025 by wenjh

[Perf] blockwise 1d better perf

Signed-off-by: wenjh <wenjh@sugon.com>
parent ca1e98b6

Showing 1 changed file with 246 additions and 145 deletions (+246 −145):
transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu
...
...
@@ -561,8 +561,8 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
      for (int i = 0; i < kThreadTileCol; ++i) {
        if constexpr (std::is_same_v<OType, int8_t>) {
          output_vec.data.elt[i] = static_cast<OType>(lroundf(
              fmaxf(-127.0f, fminf(127.0f, static_cast<CType>(reg_vec[row_idx].data.elt[i]) *
                                               thr_scale.data.elt[i]))));
        } else {
          output_vec.data.elt[i] = static_cast<OType>(
              static_cast<CType>(reg_vec[row_idx].data.elt[i]) * thr_scale.data.elt[i]);
...
...
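For readers skimming the hunk above: the int8_t branch quantizes each element by scaling it, clamping to [-127, 127], and rounding to nearest. A minimal host-side sketch of the same arithmetic follows; the helper name and parameters are illustrative only and are not part of this commit.

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // Hypothetical helper mirroring the kernel's clamp-round pattern:
    // out = round(clamp(x * scale, -127, 127)), stored as int8_t.
    static inline int8_t quantize_int8(float x, float scale) {
      const float scaled = x * scale;
      const float clamped = std::fmax(-127.0f, std::fmin(127.0f, scaled));
      return static_cast<int8_t>(std::lround(clamped));
    }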
@@ -654,6 +654,14 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
   __syncthreads();
+  // If not return columnwise, we trigger the next kernel here so that its load from global memory
+  // can overlap with this kernel's return rowwise.
+#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
+  if (!return_columnwise_gemm_ready && !return_columnwise_compact) {
+    cudaTriggerProgrammaticLaunchCompletion();
+  }
+#endif
   // Step 2: Cast and store to output_c
   if (return_rowwise) {
     constexpr int r_stride =
...
...
@@ -760,6 +768,14 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
     }
   }
+  // If return columnwise, we trigger the next kernel here so that its load from global memory
+  // can overlap with this kernel's return columnwise.
+#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
+  if (return_columnwise_gemm_ready || return_columnwise_compact) {
+    cudaTriggerProgrammaticLaunchCompletion();
+  }
+#endif
   // Step 3: Transpose, cast and store to output_t
   if (return_columnwise_gemm_ready) {
     constexpr int c_stride =
...
...
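The two cudaTriggerProgrammaticLaunchCompletion() calls added above rely on programmatic dependent launch (SM 9.0+): a dependent kernel launched with the programmatic-completion attribute may be scheduled early, and it blocks at cudaGridDependencySynchronize() only when it actually needs the producer's results, so its global-memory loads can overlap with the producer's stores. A minimal sketch of the consumer side, not taken from this commit (the kernel name and arguments are illustrative):

    #include <cuda_runtime.h>

    // Hypothetical downstream kernel: starts early, waits only where it
    // depends on data written by the producer grid.
    __global__ void consumer_kernel(const float* produced, float* out, int n) {
    #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
      cudaGridDependencySynchronize();  // wait for the producer's completion signal
    #endif
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) out[i] = produced[i] * 2.0f;
    }

On the host side, the overlap also requires launching the consumer with the programmatic stream-serialization launch attribute (for example via cudaLaunchKernelEx); that plumbing is outside this diff.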
@@ -883,7 +899,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
        "num_iterations should be 1 for columnwise non-transpose case");
    const int thr_idx_in_warp = threadIdx.x % kThreadsPerWarp;
    const int warp_idx = threadIdx.x / kThreadsPerWarp;
    if (warp_idx >= kNumColBlocks) {
      return;  // No work to do
    }
    const int r_s = thr_idx_in_warp * kThreadTileRow;  // Row in shared memory
...
...
@@ -956,8 +972,8 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
      for (int i = 0; i < kThreadTileCol; ++i) {
        if constexpr (std::is_same_v<OType, int8_t>) {
          output_vec.data.elt[i] = static_cast<OType>(lroundf(
              fmaxf(-127.0f, fminf(127.0f, static_cast<CType>(reg_vec[row_idx].data.elt[i]) *
                                               thr_scale.data.elt[i]))));
        } else {
          output_vec.data.elt[i] = static_cast<OType>(
              static_cast<CType>(reg_vec[row_idx].data.elt[i]) * thr_scale.data.elt[i]);
...
...
@@ -979,44 +995,50 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
 }
 #ifdef __HIP_PLATFORM_AMD__
-constexpr int kFP32SMemCol = kTileDim / kNVecSMem;
-constexpr int kFP32SMemSize = kSMemRow * kFP32SMemCol * kNVecSMem;
+constexpr size_t kThreadsPerWarp_blocklen_128 = 64;
+constexpr int kTileDim64_Rowwise = 64;
+constexpr int kNVecSMem_Rowwise = 4;           // The number of elements each LDS/STS touches
+constexpr int kThreadsPerBlock_Rowwise = 512;  // Thread block size, 8 warps in total
+constexpr int kSMemRow_Rowwise = kTileDim64_Rowwise;
+constexpr int kSMemCol_Rowwise = (kTileDim / kNVecSMem_Rowwise);
+constexpr int kSMemSize_Rowwise = kSMemRow_Rowwise * kSMemCol_Rowwise * kNVecSMem_Rowwise;
 template <bool kAligned, typename CType, typename IType, typename OType>
-__global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpose_kernel_fp32(
-    const IType* const input, OType* const output_c, OType* const output_t,
-    CType* const tile_scales_inv_c, CType* const tile_scales_inv_t, const size_t row_length,
-    const size_t num_rows, const size_t scale_stride_x, const size_t scale_stride_y,
-    const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon,
-    FP8BlockwiseRowwiseOption rowwise_option, FP8BlockwiseColumnwiseOption columnwise_option,
+__global__ void __launch_bounds__(kThreadsPerBlock_Rowwise) block_scaled_1d_cast_transpose_kernel_rowwise(
+    const IType* const input, OType* const output_c, CType* const tile_scales_inv_c,
+    const size_t row_length, const size_t num_rows, const size_t scale_stride_x,
+    const size_t scale_stride_y, const float epsilon, FP8BlockwiseRowwiseOption rowwise_option,
     const bool pow_2_scaling) {
   bool return_rowwise = rowwise_option != FP8BlockwiseRowwiseOption::NONE;
-  bool return_columnwise_gemm_ready = columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY;
-  bool return_columnwise_compact = columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT;
-  using SMemVec = Vec<IType, kNVecSMem>;
+  using SMemVec = Vec<IType, kNVecSMem_Rowwise>;
   using OVec = Vec<OType, kNVecOut>;
   union IVec {
     Vec<IType, kNVecIn> input_type;
-    Vec<SMemVec, kNVecIn / kNVecSMem> smem_type;
+    Vec<SMemVec, kNVecIn / kNVecSMem_Rowwise> smem_type;
   };
   extern __shared__ char smem_base[];
   SMemVec* smem = reinterpret_cast<SMemVec*>(smem_base);
   // Step 1: Load input to shared memory
   {
-    constexpr int r_stride = kThreadsPerBlock / kNumThreadsLoad;  // stride in rows of shared memory
-    constexpr int num_iterations = kTileDim / r_stride;
+    constexpr int r_stride = kThreadsPerBlock_Rowwise / kNumThreadsLoad;  // stride in rows of shared memory
+    constexpr int num_iterations = kTileDim64_Rowwise / r_stride;  // 64/16=4
-    const int c_s = (threadIdx.x % kNumThreadsLoad) * (kNVecIn / kNVecSMem);  // Column in shared memory
+    const int c_s = (threadIdx.x % kNumThreadsLoad) * (kNVecIn / kNVecSMem_Rowwise);  // Column in shared memory
     int r_s = threadIdx.x / kNumThreadsLoad;  // Row in shared memory
-    const size_t c_g = static_cast<size_t>(blockIdx.x) * kTileDim + c_s * kNVecSMem;  // Column in global memory
-    size_t r_g = static_cast<size_t>(blockIdx.y) * kTileDim + r_s;  // Row in global memory
+    const size_t c_g = static_cast<size_t>(blockIdx.x) * kTileDim + c_s * kNVecSMem_Rowwise;  // Column in global memory
+    size_t r_g = static_cast<size_t>(blockIdx.y) * kTileDim64_Rowwise + r_s;  // Row in global memory
     const size_t stride_g = static_cast<size_t>(r_stride) * row_length;  // Stride in global memory
-    const size_t num_ele = c_g < row_length ? min(static_cast<size_t>(kNVecIn), row_length - c_g)
+    const size_t num_ele = c_g < row_length ? std::min(static_cast<size_t>(kNVecIn), row_length - c_g)
                                             : 0;  // For not aligned case
     const IType* input_g = &input[r_g * row_length + c_g];  // Input address in global memory
 #pragma unroll
...
...
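As a sanity check on the new rowwise constants, and assuming kTileDim is 128 (the block_len handled by this code path) and 4-byte float inputs, the tile buffer size works out as below. This is a standalone illustration, not code from the repository, and the kTileDim value is an assumption.

    #include <cstdio>

    int main() {
      constexpr int kTileDim = 128;  // assumed: the 128-element block length this path handles
      constexpr int kTileDim64_Rowwise = 64;
      constexpr int kNVecSMem_Rowwise = 4;
      constexpr int kSMemRow_Rowwise = kTileDim64_Rowwise;            // 64 rows
      constexpr int kSMemCol_Rowwise = kTileDim / kNVecSMem_Rowwise;  // 32 vector columns
      constexpr int kSMemSize_Rowwise =
          kSMemRow_Rowwise * kSMemCol_Rowwise * kNVecSMem_Rowwise;    // 8192 elements
      std::printf("rowwise smem: %d elements, %zu bytes for float input\n",
                  kSMemSize_Rowwise, kSMemSize_Rowwise * sizeof(float));  // 32 KiB
      return 0;
    }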
@@ -1032,13 +1054,13 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
          input_vec.input_type.clear();
        }
      }
-      // Step 1.2: Write to shared memory - Column Major
+      // Step 1.2: Write to shared memory - Row Major
 #pragma unroll
-      for (int i = 0; i < kNVecIn / kNVecSMem; ++i) {
+      for (int i = 0; i < kNVecIn / kNVecSMem_Rowwise; ++i) {
        int c = c_s + i;
        int r = r_s;
-        // Column Major Store
-        smem[c * kTileDim + r] = input_vec.smem_type.data.elt[i];
+        // Row Major Store
+        smem[r * kSMemCol_Rowwise + c] = input_vec.smem_type.data.elt[i];
      }
      // Step 1.3: Update input address, row index of shared memory, (and row index of global memory for not aligned case)
      input_g += stride_g;
...
...
@@ -1054,63 +1076,62 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
   // Step 2: Cast and store to output_c
   if (return_rowwise) {
-    constexpr int r_stride = kThreadsPerBlock / kNumThreadsStore;  // stride in rows of shared memory
-    constexpr int num_iterations = kTileDim / r_stride;
-    const int c_s = (threadIdx.x % kNumThreadsStore) * (kNVecOut / kNVecSMem);  // Column in shared memory
+    constexpr int r_stride = kThreadsPerBlock_Rowwise / kNumThreadsStore;  // stride in rows of shared memory
+    constexpr int num_iterations = kTileDim64_Rowwise / r_stride;
+    const int c_s = (threadIdx.x % kNumThreadsStore) * (kNVecOut / kNVecSMem_Rowwise);  // Column in shared memory
     int r_s = threadIdx.x / kNumThreadsStore;  // Row in shared memory
-    const size_t c_g = static_cast<size_t>(blockIdx.x) * kTileDim + c_s * kNVecSMem;  // Column in global memory
-    size_t r_g = static_cast<size_t>(blockIdx.y) * kTileDim + r_s;  // Row in global memory
+    const size_t c_g = static_cast<size_t>(blockIdx.x) * kTileDim + c_s * kNVecSMem_Rowwise;  // Column in global memory
+    size_t r_g = static_cast<size_t>(blockIdx.y) * kTileDim64_Rowwise + r_s;  // Row in global memory
     const size_t stride_g = static_cast<size_t>(r_stride) * row_length;  // Stride in global memory
-    const size_t num_ele = c_g < row_length ? min(static_cast<size_t>(kNVecOut), row_length - c_g)
+    const size_t num_ele = c_g < row_length ? std::min(static_cast<size_t>(kNVecOut), row_length - c_g)
                                             : 0;  // For not aligned case
     OType* output_g = &output_c[r_g * row_length + c_g];  // Output address in global memory
     // Each kNumThreadsStore threads form a warp process one row, we need to find the lane id of
     // the first thread to do the reduction.
-    const unsigned src_lane = (threadIdx.x % kThreadsPerWarp) / kNumThreadsStore * kNumThreadsStore;
+    const unsigned src_lane = (threadIdx.x % kThreadsPerWarp_blocklen_128) / kNumThreadsStore * kNumThreadsStore;
     // This mask represents which threads should do the reduction together.
     const unsigned mask = ((1 << kNumThreadsStore) - 1) << src_lane;
     const bool is_src_lane = (threadIdx.x % kNumThreadsStore) == 0;
 #pragma unroll
     for (int iter = 0; iter < num_iterations; ++iter) {
-      SMemVec smem_vec[kNVecOut / kNVecSMem];
+      SMemVec smem_vec[kNVecOut / kNVecSMem_Rowwise];
       // Step 2.1: Load from shared memory to registers - Column Major
 #pragma unroll
-      for (int i = 0; i < kNVecOut / kNVecSMem; ++i) {
+      for (int i = 0; i < kNVecOut / kNVecSMem_Rowwise; ++i) {
        int c = c_s + i;
        int r = r_s;
        // Column Major Read
-        smem_vec[i] = smem[c * kTileDim + r];
+        smem_vec[i] = smem[r * kSMemCol_Rowwise + c];
      }
      // Step 2.2: Compute local amax
      CType amax = 0;
 #pragma unroll
-      for (int i = 0; i < kNVecOut / kNVecSMem; ++i) {
+      for (int i = 0; i < kNVecOut / kNVecSMem_Rowwise; ++i) {
 #pragma unroll
-        for (int j = 0; j < kNVecSMem; ++j) {
+        for (int j = 0; j < kNVecSMem_Rowwise; ++j) {
          __builtin_assume(amax >= 0);
          amax = fmaxf(amax, fabsf(smem_vec[i].data.elt[j]));
        }
      }
      // Step 2.3: Reduce amax
 #pragma unroll
      for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) {
-#ifdef __HIP_PLATFORM_AMD__
-        const float other_amax = __shfl_down_sync((unsigned long long)(mask), amax, delta, kThreadsPerWarp);
-#else
-        const float other_amax = __shfl_down_sync(mask, amax, delta);
-#endif
+        // const float other_amax = __shfl_xor_sync((unsigned long long)(mask), amax, delta, kThreadsPerWarp);
+        const float other_amax = __shfl_xor(amax, delta, kThreadsPerWarp_blocklen_128);
        __builtin_assume(amax >= 0);
        __builtin_assume(other_amax >= 0);
        amax = fmaxf(amax, other_amax);
      }
 #ifdef __HIP_PLATFORM_AMD__
      amax = __shfl_sync((unsigned long long)(mask), amax, src_lane, kThreadsPerWarp);
 #else
      amax = __shfl_sync(mask, amax, src_lane);
 #endif
      CType scale;
      // Step 2.4: Compute scale
      scale = compute_scale_from_types<IType, OType>(amax, epsilon, pow_2_scaling);
...
...
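The hunk above replaces the masked __shfl_down_sync tree with an XOR-butterfly reduction across the 64-lane wavefront. A minimal CUDA-style sketch of the same pattern, written here for a standard 32-lane warp with __shfl_xor_sync (the helper name is illustrative, not from this file): after log2(width) butterfly steps every participating lane already holds the group maximum, so no separate broadcast from a source lane is required within the group.

    // Butterfly max-reduction over one warp (or a power-of-two subgroup of it).
    __device__ float warp_amax_xor(float amax, int width = 32) {
      for (int delta = width / 2; delta > 0; delta /= 2) {
        const float other = __shfl_xor_sync(0xFFFFFFFFu, amax, delta, width);
        amax = fmaxf(amax, other);  // every lane converges to the group max
      }
      return amax;
    }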
@@ -1121,21 +1142,21 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
      }
      if (write_scale_inv) {
        CType scale_inv = 1.0 / scale;
-        size_t row_idx = static_cast<size_t>(blockIdx.y) * kTileDim + r_s;
+        size_t row_idx = static_cast<size_t>(blockIdx.y) * kTileDim64_Rowwise + r_s;
        size_t col_idx = static_cast<size_t>(blockIdx.x);
        tile_scales_inv_c[row_idx * scale_stride_y + col_idx * scale_stride_x] = scale_inv;
      }
      // Step 2.6: Quantize
      OVec output_vec;
 #pragma unroll
-      for (int i = 0; i < kNVecOut / kNVecSMem; ++i) {
+      for (int i = 0; i < kNVecOut / kNVecSMem_Rowwise; ++i) {
 #pragma unroll
-        for (int j = 0; j < kNVecSMem; ++j) {
+        for (int j = 0; j < kNVecSMem_Rowwise; ++j) {
          if constexpr (std::is_same_v<OType, int8_t>) {
-            output_vec.data.elt[i * kNVecSMem + j] = static_cast<OType>(lroundf(fmaxf(
+            output_vec.data.elt[i * kNVecSMem_Rowwise + j] = static_cast<OType>(lroundf(fmaxf(
                -127.0f, fminf(127.0f, static_cast<CType>(smem_vec[i].data.elt[j]) * scale))));
          } else {
-            output_vec.data.elt[i * kNVecSMem + j] =
+            output_vec.data.elt[i * kNVecSMem_Rowwise + j] =
                static_cast<OType>(static_cast<CType>(smem_vec[i].data.elt[j]) * scale);
          }
        }
...
...
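The scale itself comes from compute_scale_from_types, whose implementation is outside this diff. Shown below is only a hedged sketch of what an amax-based, optionally power-of-two scale typically looks like; the function name, signature, and logic are assumptions for illustration, not this repository's code (448.0f is the FP8 E4M3 maximum, used here purely as an example format maximum).

    #include <cmath>

    // Hypothetical sketch: pick scale so that amax maps near the output
    // format's representable maximum; optionally round down to a power of two.
    static inline float sketch_compute_scale(float amax, float epsilon, bool pow_2_scaling,
                                             float fmt_max /* e.g. 448.0f for FP8 E4M3 */) {
      const float clipped = std::fmax(amax, epsilon);    // avoid dividing by ~0
      float scale = fmt_max / clipped;
      if (pow_2_scaling) {
        scale = std::exp2(std::floor(std::log2(scale)));  // round down to a power of two
      }
      return scale;
    }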
@@ -1156,28 +1177,109 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
        }
      }
    }
  }
+constexpr int kTileDim64_Colwise = 64;
+constexpr int kNVecSMem_Colwise = 2;
+constexpr int kSMemRow_Colwise = kTileDim;
+constexpr int kSMemCol_Colwise = (kTileDim64_Colwise / kNVecSMem_Colwise);
+constexpr int kSMemSize_Colwise = kSMemRow_Colwise * kSMemCol_Colwise * kNVecSMem_Colwise;
+constexpr int kNumThreadsLoad_Colwise = kTileDim64_Colwise / kNVecIn;
+constexpr int kNumThreadsStore_Colwise = kTileDim / kNVecOut;
+constexpr int kThreadsPerBlock_Colwise = 256;
+constexpr int kNumWarps_Colwise = kThreadsPerBlock_Colwise / kThreadsPerWarp_blocklen_128;
+template <bool kAligned, typename CType, typename IType, typename OType>
+__global__ void __launch_bounds__(kThreadsPerBlock_Colwise) block_scaled_1d_cast_transpose_kernel_colwise(
+    const IType* const input, OType* const output_t, CType* const tile_scales_inv_t,
+    const size_t row_length, const size_t num_rows, const size_t scale_t_stride_x,
+    const size_t scale_t_stride_y, const float epsilon,
+    FP8BlockwiseColumnwiseOption columnwise_option, const bool pow_2_scaling) {
+  bool return_columnwise_gemm_ready = columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY;
+  bool return_columnwise_compact = columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT;
+  using SMemVec = Vec<IType, kNVecSMem_Colwise>;
+  using OVec = Vec<OType, kNVecOut>;
+  union IVec {
+    Vec<IType, kNVecIn> input_type;
+    Vec<SMemVec, kNVecIn / kNVecSMem_Colwise> smem_type;
+  };
+  extern __shared__ char smem_base[];
+  SMemVec* smem = reinterpret_cast<SMemVec*>(smem_base);
+  // Step 1: Load input to shared memory
+  {
+    constexpr int r_stride = kThreadsPerBlock_Colwise / kNumThreadsLoad_Colwise;  // stride in rows of shared memory
+    constexpr int num_iterations = kTileDim / r_stride;
+    const int c_s = (threadIdx.x % kNumThreadsLoad_Colwise) * (kNVecIn / kNVecSMem_Colwise);  // Column in shared memory
+    int r_s = threadIdx.x / kNumThreadsLoad_Colwise;  // Row in shared memory
+    const size_t c_g = static_cast<size_t>(blockIdx.x) * kTileDim64_Colwise + c_s * kNVecSMem_Colwise;  // Column in global memory
+    size_t r_g = static_cast<size_t>(blockIdx.y) * kTileDim + r_s;  // Row in global memory
+    const size_t stride_g = static_cast<size_t>(r_stride) * row_length;  // Stride in global memory
+    const size_t num_ele = c_g < row_length ? std::min(static_cast<size_t>(kNVecIn), row_length - c_g)
+                                            : 0;  // For not aligned case
+    const IType* input_g = &input[r_g * row_length + c_g];  // Input address in global memory
+#pragma unroll
+    for (int iter = 0; iter < num_iterations; ++iter) {
+      IVec input_vec;
+      // Step 1.1: Load from global memory (input) to registers
+      if constexpr (kAligned) {
+        input_vec.input_type.load_from(input_g);
+      } else {
+        if (r_g < num_rows) {
+          input_vec.input_type.load_from_elts(input_g, 0, num_ele);
+        } else {
+          input_vec.input_type.clear();
+        }
+      }
+      // Step 1.2: Write to shared memory - Row Major
+#pragma unroll
+      for (int i = 0; i < kNVecIn / kNVecSMem_Colwise; ++i) {
+        int c = c_s + i;
+        int r = r_s;
+        // Row Major Store
+        smem[r * kSMemCol_Colwise + c] = input_vec.smem_type.data.elt[i];
+      }
+      // Step 1.3: Update input address, row index of shared memory, (and row index of global memory for not aligned case)
+      input_g += stride_g;
+      r_s += r_stride;
+      if constexpr (!kAligned) {
+        r_g += r_stride;
+      }
+    }
+  }
+  __syncthreads();
   // Step 3: Transpose, cast and store to output_t
   if (return_columnwise_gemm_ready) {
-    constexpr int c_stride = kThreadsPerBlock / kNumThreadsStore;  // Stride in columns of shared memory
-    constexpr int num_iterations = kTileDim / (c_stride * kNVecSMem);
-    const int r_s = (threadIdx.x % kNumThreadsStore) * kNVecOut;  // Row in shared memory
-    int c_s = threadIdx.x / kNumThreadsStore;  // Column in shared memory
-    size_t r_g = static_cast<size_t>(blockIdx.x) * kTileDim + c_s * kNVecSMem;  // Row in global memory
+    constexpr int c_stride = kThreadsPerBlock_Colwise / kNumThreadsStore_Colwise;  // Stride in columns of shared memory
+    constexpr int num_iterations = kTileDim64_Colwise / (c_stride * kNVecSMem_Colwise);
+    const int r_s = (threadIdx.x % kNumThreadsStore_Colwise) * kNVecOut;  // Row in shared memory
+    int c_s = threadIdx.x / kNumThreadsStore_Colwise;  // Column in shared memory
+    size_t r_g = static_cast<size_t>(blockIdx.x) * kTileDim64_Colwise + c_s * kNVecSMem_Colwise;  // Row in global memory
     const size_t c_g = static_cast<size_t>(blockIdx.y) * kTileDim + r_s;  // Column in global memory
-    const size_t stride_g = static_cast<size_t>(c_stride) * kNVecSMem * num_rows;  // Stride in global memory
-    const size_t num_ele = c_g < num_rows ? min(static_cast<size_t>(kNVecOut), num_rows - c_g)
+    const size_t stride_g = static_cast<size_t>(c_stride) * kNVecSMem_Colwise * num_rows;  // Stride in global memory
+    const size_t num_ele = c_g < num_rows ? std::min(static_cast<size_t>(kNVecOut), num_rows - c_g)
                                           : 0;  // For not aligned case
    OType* output_g = &output_t[r_g * num_rows + c_g];  // Output address in global memory
    // Each kNumThreadsStore threads form a warp process one row, we need to find the lane id of
    // the first thread to do the reduction.
-    const unsigned src_lane = (threadIdx.x % kThreadsPerWarp) / kNumThreadsStore * kNumThreadsStore;
+    const unsigned src_lane = (threadIdx.x % kThreadsPerWarp_blocklen_128) / kNumThreadsStore_Colwise * kNumThreadsStore_Colwise;
    // This mask represents which threads should do the reduction together.
-    const unsigned mask = ((1 << kNumThreadsStore) - 1) << src_lane;
-    const bool is_src_lane = (threadIdx.x % kNumThreadsStore) == 0;
+    const unsigned mask = ((1 << kNumThreadsStore_Colwise) - 1) << src_lane;
+    const bool is_src_lane = (threadIdx.x % kNumThreadsStore_Colwise) == 0;
 #pragma unroll
    for (int iter = 0; iter < num_iterations; ++iter) {
      SMemVec smem_vec[kNVecOut];
...
...
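Both shared-memory layouts appear in this diff: the old column-major indexing smem[c * kTileDim + r] and the new row-major indexing smem[r * kSMemCol_Colwise + c] over small vectors of kNVecSMem elements. A tiny standalone comparison of the two address functions (the dimensions in the example are made up); the commit title suggests the row-major form is the performance-motivated choice here, since consecutive threads in the load phase then touch consecutive vector columns of the same row.

    // Index of the vector holding (row r, vector-column c) in a tile_rows x tile_cols
    // scratch buffer, for the two layouts seen before/after this commit.
    constexpr int col_major_index(int r, int c, int tile_rows) { return c * tile_rows + r; }
    constexpr int row_major_index(int r, int c, int tile_cols) { return r * tile_cols + c; }

    // Example with 128 rows and 32 vector columns (64 elements stored as vectors of 2).
    static_assert(col_major_index(5, 3, 128) == 3 * 128 + 5, "column-major address");
    static_assert(row_major_index(5, 3, 32) == 5 * 32 + 3, "row-major address");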
@@ -1186,11 +1288,11 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
      for (int i = 0; i < kNVecOut; ++i) {
        int r = r_s + i;
        int c = c_s;
-        // Column Major Read
-        smem_vec[i] = smem[c * kTileDim + r];
+        // Row Major Read
+        smem_vec[i] = smem[r * kSMemCol_Colwise + c];
      }
 #pragma unroll
-      for (int smem_idx = 0; smem_idx < kNVecSMem; ++smem_idx) {
+      for (int smem_idx = 0; smem_idx < kNVecSMem_Colwise; ++smem_idx) {
        // Step 3.2: Compute local amax
        CType amax = 0;
 #pragma unroll
...
...
@@ -1199,22 +1301,16 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
        }
        // Step 3.3: Reduce amax
 #pragma unroll
-        for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) {
-#ifdef __HIP_PLATFORM_AMD__
-          const float other_amax = __shfl_down_sync((unsigned long long)(mask), amax, delta, kThreadsPerWarp);
-#else
-          const float other_amax = __shfl_down_sync(mask, amax, delta);
-#endif
+        for (int delta = kNumThreadsStore_Colwise / 2; delta > 0; delta /= 2) {
+          // const float other_amax =
+          //     __shfl_xor_sync((unsigned long long)(mask), amax, delta, kThreadsPerWarp);
+          const float other_amax = __shfl_xor(amax, delta, kThreadsPerWarp_blocklen_128);
          __builtin_assume(amax >= 0);
          __builtin_assume(other_amax >= 0);
          amax = fmaxf(amax, other_amax);
        }
 #ifdef __HIP_PLATFORM_AMD__
        amax = __shfl_sync((unsigned long long)(mask), amax, src_lane, kThreadsPerWarp);
 #else
        amax = __shfl_sync(mask, amax, src_lane);
 #endif
        // Step 3.4: Compute scale
        CType scale;
        scale = compute_scale_from_types<IType, OType>(amax, epsilon, pow_2_scaling);
...
...
@@ -1225,7 +1321,8 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
        }
        if (write_scale_inv) {
          CType scale_inv = 1.0 / scale;
-          size_t row_idx = static_cast<size_t>(blockIdx.x) * kTileDim + c_s * kNVecSMem + smem_idx;
+          size_t row_idx = static_cast<size_t>(blockIdx.x) * kTileDim64_Colwise + c_s * kNVecSMem_Colwise + smem_idx;
          size_t col_idx = static_cast<size_t>(blockIdx.y);
          tile_scales_inv_t[row_idx * scale_t_stride_y + col_idx * scale_t_stride_x] = scale_inv;
        }
...
...
@@ -1255,34 +1352,33 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
      output_g += stride_g;
      c_s += c_stride;
      if constexpr (!kAligned) {
-        r_g += c_stride * kNVecSMem;
+        r_g += c_stride * kNVecSMem_Colwise;
      }
    }
  }
  // Step 4 (return_columnwise_compact): cast in 128x1 style and store to output, skip transpose
  if (return_columnwise_compact) {
    // thread tile should be 4x16, 16 means 8 smem reads
-    constexpr int kThreadTileRow = kTileDim / kThreadsPerWarp;
+    constexpr int kThreadTileRow = kTileDim / kThreadsPerWarp_blocklen_128;
    constexpr int kThreadTileCol = kNVecOut;
    using RegVec = Vec<IType, kThreadTileCol>;
-    using RegScaleVec = Vec<CType, kThreadTileCol>;
-    constexpr int num_smem_reads = kNVecOut / kNVecSMem;
+    using RegScaleVec = Vec<CType, kThreadTileCol>;  // float,16
+    constexpr int num_smem_reads = kNVecOut / kNVecSMem_Colwise;
    // c_stride will not be used here because we only have one iteration
    // constexpr int c_stride = kThreadTileCol * kNumWarps / kNVecSMem;
    constexpr int num_iterations =
-        kTileDim / (kNumWarps * kThreadTileCol);  // should be only one iteration
+        kTileDim64_Colwise / (kNumWarps_Colwise * kThreadTileCol);  // should be only one iteration
    static_assert(num_iterations == 1, "num_iterations should be 1 for columnwise non-transpose case");
-    const int thr_idx_in_warp = threadIdx.x % kThreadsPerWarp;
-    const int warp_idx = threadIdx.x / kThreadsPerWarp;
+    const int thr_idx_in_warp = threadIdx.x % kThreadsPerWarp_blocklen_128;
+    const int warp_idx = threadIdx.x / kThreadsPerWarp_blocklen_128;
    const int r_s = thr_idx_in_warp * kThreadTileRow;  // Row in shared memory
    int c_s = warp_idx * num_smem_reads;  // Column in shared memory
    size_t r_g = static_cast<size_t>(blockIdx.y) * kTileDim + r_s;  // Row in global memory
-    const size_t c_g = static_cast<size_t>(blockIdx.x) * kTileDim + c_s * kNVecSMem;  // Column in global memory
+    const size_t c_g = static_cast<size_t>(blockIdx.x) * kTileDim64_Colwise + c_s * kNVecSMem_Colwise;  // Column in global memory
    const size_t num_ele = c_g < row_length
-                              ? min(static_cast<size_t>(kThreadTileCol), row_length - c_g)
+                              ? std::min(static_cast<size_t>(kThreadTileCol), row_length - c_g)
                              : 0;  // For not aligned case
 #pragma unroll
    for (int iter = 0; iter < num_iterations; ++iter) {
...
...
@@ -1296,11 +1392,11 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
 #pragma unroll
      for (int j = 0; j < num_smem_reads; ++j) {
        int c = c_s + j;
-        SMemVec smem_vec = smem[c * kTileDim + r];
+        SMemVec smem_vec = smem[r * kSMemCol_Colwise + c];
        // copy smem_vec to reg vec with its elements
 #pragma unroll
-        for (int k = 0; k < kNVecSMem; ++k) {
-          reg_vec[i].data.elt[j * kNVecSMem + k] = smem_vec.data.elt[k];
+        for (int k = 0; k < kNVecSMem_Colwise; ++k) {
+          reg_vec[i].data.elt[j * kNVecSMem_Colwise + k] = smem_vec.data.elt[k];
        }
      }
    }
...
...
@@ -1314,13 +1410,16 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
    }
    // Step 3.3: Reduce amax
    const bool is_src_lane = thr_idx_in_warp == 0;
-    amax = warp_reduce_max<kThreadsPerWarp>(amax);
-    constexpr int lane_zero = 0;
-#ifdef __HIP_PLATFORM_AMD__
-    amax = __shfl_sync((unsigned long long)(0xFFFFFFFF), amax, lane_zero, kThreadsPerWarp);
-#else
-    amax = __shfl_sync(0xFFFFFFFF, amax, lane_zero);
-#endif
+#pragma unroll
+    for (int delta = kThreadsPerWarp_blocklen_128 / 2; delta > 0; delta /= 2) {
+      // const float other_amax =
+      //     __shfl_xor_sync((unsigned long long)(0xFFFFFFFF), amax, delta, kThreadsPerWarp);
+      const float other_amax = __shfl_xor(amax, delta, kThreadsPerWarp_blocklen_128);
+      __builtin_assume(amax >= 0);
+      __builtin_assume(other_amax >= 0);
+      amax = fmaxf(amax, other_amax);
+    }
    // Step 3.4: Compute scale
    CType scale;
    scale = compute_scale_from_types<IType, OType>(amax, epsilon, pow_2_scaling);
...
...
@@ -1333,7 +1432,8 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
    if (write_scale_inv) {
      CType scale_inv = 1.0 / scale;
      size_t row_idx = static_cast<size_t>(blockIdx.y);
-      size_t col_idx = static_cast<size_t>(blockIdx.x) * kTileDim + c_s * kNVecSMem + reg_idx;
+      size_t col_idx = static_cast<size_t>(blockIdx.x) * kTileDim64_Colwise + c_s * kNVecSMem_Colwise + reg_idx;
      tile_scales_inv_t[row_idx * scale_t_stride_y + col_idx * scale_t_stride_x] = scale_inv;
    }
  }
...
...
@@ -1346,8 +1446,8 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
      for (int i = 0; i < kThreadTileCol; ++i) {
        if constexpr (std::is_same_v<OType, int8_t>) {
          output_vec.data.elt[i] = static_cast<OType>(lroundf(
              fmaxf(-127.0f, fminf(127.0f, static_cast<CType>(reg_vec[row_idx].data.elt[i]) *
                                               thr_scale.data.elt[i]))));
        } else {
          output_vec.data.elt[i] = static_cast<OType>(
              static_cast<CType>(reg_vec[row_idx].data.elt[i]) * thr_scale.data.elt[i]);
...
...
@@ -1471,41 +1571,42 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
 #ifdef __HIP_PLATFORM_AMD__
   while (true) {
     if (128 == block_len) {
       if constexpr (std::is_same_v<InputType, float>) {
-        size_t smem_bytes = kFP32SMemSize * sizeof(InputType);
+        if (rowwise_option != FP8BlockwiseRowwiseOption::NONE) {
+          size_t smem_bytes = kSMemSize_Rowwise * sizeof(InputType);
          const size_t num_blocks_x = DIVUP(row_length, (size_t)(block_len));
          const size_t num_blocks_y = DIVUP(num_rows, (size_t)(block_len / 2));
          dim3 grid(num_blocks_x, num_blocks_y, 1);
          if (smem_bytes >= 48 * 1024) {
            cudaError_t err = cudaFuncSetAttribute(
-               (const void*)&block_scaled_1d_cast_transpose_kernel_fp32<kAligned, float, InputType, OutputType>,
+               (const void*)&block_scaled_1d_cast_transpose_kernel_rowwise<kAligned, float, InputType, OutputType>,
                cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
          }
-         block_scaled_1d_cast_transpose_kernel_fp32<kAligned, float, InputType, OutputType>
-             <<<grid, kThreadsPerBlock, smem_bytes, stream>>>(
-                 reinterpret_cast<const InputType*>(input.dptr),
-                 reinterpret_cast<OutputType*>(output.dptr),
-                 reinterpret_cast<OutputType*>(output_t.dptr),
-                 reinterpret_cast<float*>(scale_inv.dptr),
-                 reinterpret_cast<float*>(scale_inv_t.dptr),
-                 row_length, num_rows, scale_stride_x, scale_stride_y, scale_t_stride_x,
-                 scale_t_stride_y, epsilon, rowwise_option, columnwise_option, pow2_scale);
-       } else {
-         size_t smem_bytes = kSMemSize * sizeof(InputType);
-         if (smem_bytes >= 48 * 1024) {
-           cudaError_t err = cudaFuncSetAttribute(
-               (const void*)&block_scaled_1d_cast_transpose_kernel<kAligned, float, InputType, OutputType>,
-               cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
-         }
-         block_scaled_1d_cast_transpose_kernel<kAligned, float, InputType, OutputType>
-             <<<grid, kThreadsPerBlock, smem_bytes, stream>>>(
-                 reinterpret_cast<const InputType*>(input.dptr),
-                 reinterpret_cast<OutputType*>(output.dptr),
-                 reinterpret_cast<OutputType*>(output_t.dptr),
-                 reinterpret_cast<float*>(scale_inv.dptr),
-                 reinterpret_cast<float*>(scale_inv_t.dptr),
-                 row_length, num_rows, scale_stride_x, scale_stride_y, scale_t_stride_x,
-                 scale_t_stride_y, epsilon, rowwise_option, columnwise_option, pow2_scale);
+         block_scaled_1d_cast_transpose_kernel_rowwise<kAligned, float, InputType, OutputType>
+             <<<grid, kThreadsPerBlock_Rowwise, smem_bytes, stream>>>(
+                 reinterpret_cast<const InputType*>(input.dptr),
+                 reinterpret_cast<OutputType*>(output.dptr),
+                 reinterpret_cast<float*>(scale_inv.dptr),
+                 row_length, num_rows, scale_stride_x, scale_stride_y, epsilon,
+                 rowwise_option, pow2_scale);
+       }
+       if (columnwise_option != FP8BlockwiseColumnwiseOption::NONE) {
+         size_t smem_bytes = kSMemSize_Colwise * sizeof(InputType);
+         const size_t num_blocks_x = DIVUP(row_length, (size_t)(block_len / 2));
+         const size_t num_blocks_y = DIVUP(num_rows, (size_t)(block_len));
+         dim3 grid(num_blocks_x, num_blocks_y, 1);
+         if (smem_bytes >= 48 * 1024) {
+           cudaError_t err = cudaFuncSetAttribute(
+               (const void*)&block_scaled_1d_cast_transpose_kernel_colwise<kAligned, float, InputType, OutputType>,
+               cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
+         }
+         block_scaled_1d_cast_transpose_kernel_colwise<kAligned, float, InputType, OutputType>
+             <<<grid, kThreadsPerBlock_Colwise, smem_bytes, stream>>>(
+                 reinterpret_cast<const InputType*>(input.dptr),
+                 reinterpret_cast<OutputType*>(output_t.dptr),
+                 reinterpret_cast<float*>(scale_inv_t.dptr),
+                 row_length, num_rows, scale_t_stride_x, scale_t_stride_y, epsilon,
+                 columnwise_option, pow2_scale);
        }
        break;
      }
...
...
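The launcher above sizes the grid differently for the two new kernels: the rowwise launch tiles columns by block_len (128) and rows by block_len / 2 (64), while the colwise launch swaps the two. A standalone illustration of the ceiling-division arithmetic (the matrix shape is made up; divup mirrors what the DIVUP macro computes):

    #include <cstdio>

    constexpr size_t divup(size_t a, size_t b) { return (a + b - 1) / b; }

    int main() {
      const size_t row_length = 5120, num_rows = 4096, block_len = 128;
      // Rowwise kernel: 128-wide x 64-tall tiles.
      std::printf("rowwise grid: %zu x %zu\n",
                  divup(row_length, block_len), divup(num_rows, block_len / 2));
      // Colwise kernel: 64-wide x 128-tall tiles.
      std::printf("colwise grid: %zu x %zu\n",
                  divup(row_length, block_len / 2), divup(num_rows, block_len));
      return 0;
    }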