OpenDAS / TransformerEngine · Commit 1f9c104b
Authored Jun 18, 2025 by wenjh
Merge branch 'develop_v2.4'
Parents: 2b1428ff, 8a03ff34
Showing 6 changed files with 369 additions and 207 deletions (+369, -207)
tests/cpp/operator/test_cast_float8blockwise.cu  (+3, -33)
tests/pytorch/references/blockwise_quantizer_reference.py  (+5, -1)
tests/pytorch/references/quantize_scale_calc.py  (+6, -2)
tests/pytorch/test_float8_blockwise_scaling_exact.py  (+4, -4)
transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu  (+9, -2)
transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu  (+342, -165)
tests/cpp/operator/test_cast_float8blockwise.cu
...
...
@@ -263,30 +263,16 @@ void compare_scaling_factors(const std::string& name, const float* test, const f
 void compare_scaling_factors_one_dimensional_blocks(const std::string& name, const float* test,
                                                     const float* ref, const size_t rows,
-                                                    const size_t col_blocks
-#ifdef __HIP_PLATFORM_AMD__
-                                                    , double atol = 0., double rtol = 0.
-#endif
-) {
+                                                    const size_t col_blocks) {
   const size_t test_stride = scale_align_stride(rows);
   for (int i = 0; i < rows; ++i) {
     for (int j = 0; j < col_blocks; ++j) {
       const int test_idx = i + test_stride * j;
       const int ref_idx = i + rows * j;
-#ifdef __HIP_PLATFORM_AMD__
-      double t = static_cast<double>(static_cast<float>(test[test_idx]));
-      double r = static_cast<double>(static_cast<float>(ref[ref_idx]));
-      bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol);
-      ASSERT_FALSE(mismatch) << "Error in " << name << std::endl
-                             << "Mismatch: " << t << " vs " << r << " at index " << test_idx << "," << ref_idx;
-#else
       ASSERT_FALSE(test[test_idx] != ref[ref_idx])
           << "Error in " << name << std::endl
           << "Mismatch: " << test[test_idx] << " vs " << ref[ref_idx] << " at index " << test_idx << "," << ref_idx;
-#endif
     }
   }
 }
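For reference, the tolerance test that the removed HIP-specific branch applied (flag a mismatch only when both the absolute and the relative error are exceeded) can be restated as a small standalone Python sketch; this is illustrative only and not part of the diff:

    def scales_mismatch(t: float, r: float, atol: float = 0.0, rtol: float = 0.0) -> bool:
        # Mirrors: fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol)
        return abs(t - r) > atol and (r == 0 or abs((t - r) / r) > rtol)

    assert not scales_mismatch(1.0000099, 1.0, atol=1e-5)   # inside atol
    assert scales_mismatch(1.1, 1.0, atol=1e-5, rtol=1e-2)  # outside both tolerances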
...
...
@@ -425,33 +411,17 @@ void runTestCaseOneDimensionalBlocks(const ProcessingMethod processing_method,
   float atol = 0.0;
   float rtol = 0.0;
-#ifdef __HIP_PLATFORM_AMD__
-  double atol_scale = 0.0;
-  double rtol_scale = 0.0;
-  if (itype == DType::kFloat32) {
-    atol_scale = 1e-5;
-  }
-#endif
   if (rowwise) {
     compareResults("output_c", output_c, ref_output.get(), true, atol, rtol);
     compare_scaling_factors_one_dimensional_blocks("scale_inv", output_c.rowwise_cpu_scale_inv_ptr<float>(),
-                                                   ref_scale_inv.get(), rows, blocks_x
-#ifdef __HIP_PLATFORM_AMD__
-                                                   , atol_scale, rtol_scale
-#endif
-    );
+                                                   ref_scale_inv.get(), rows, blocks_x);
   }
   if (colwise) {
     compareResults("output_c_t", output_c, ref_output_t.get(), false, atol, rtol);
     compare_scaling_factors_one_dimensional_blocks("scale_inv_t", output_c.columnwise_cpu_scale_inv_ptr<float>(),
-                                                   ref_scale_inv_t.get(), cols, blocks_x_t
-#ifdef __HIP_PLATFORM_AMD__
-                                                   , atol_scale, rtol_scale
-#endif
-    );
+                                                   ref_scale_inv_t.get(), cols, blocks_x_t);
   }
 }
...
...
tests/pytorch/references/blockwise_quantizer_reference.py
...
...
@@ -171,7 +171,11 @@ class BlockwiseQuantizerReference:
         qx = x_tiled * scale.reshape(M, K // tile_len, 1)
         qx = torch.clamp(qx, min=-dtype_max, max=dtype_max)
         if quant_dtype == torch.int8:
-            qx = torch.round(qx)
+            positive_mask = qx >= 0
+            negative_mask = ~positive_mask
+            pos_part = torch.where(positive_mask, torch.floor(qx + 0.5), 0)
+            neg_part = torch.where(negative_mask, torch.ceil(qx - 0.5), 0)
+            qx = pos_part + neg_part
         qx = qx.to(dtype=quant_dtype)
         qx = qx.reshape(M, K)
         return qx, scale_inv
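The floor(qx + 0.5) / ceil(qx - 0.5) construction implements round-half-away-from-zero, which is what lroundf does in the CUDA kernels, whereas torch.round rounds halves to the nearest even value. A short illustrative sketch of the difference (not part of the commit):

    import torch

    x = torch.tensor([0.5, 1.5, 2.5, -0.5, -1.5])
    half_to_even = torch.round(x)  # tensor([ 0.,  2.,  2., -0., -2.])
    half_away = torch.where(x >= 0, torch.floor(x + 0.5), torch.ceil(x - 0.5))
    print(half_away)               # tensor([ 1.,  2.,  3., -1., -2.]), matches C's lroundf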
...
...
tests/pytorch/references/quantize_scale_calc.py
...
...
@@ -4,7 +4,7 @@
 from typing import Tuple

 import torch
+from torch.utils.cpp_extension import IS_HIP_EXTENSION


 def scale_from_amax_tensor(
     x_dtype: torch.dtype,
...
...
@@ -48,6 +48,10 @@ def scale_from_amax_tensor(
     # No subnormals and zero.
     assert (exp > -127).all()
     unity = torch.tensor([1.0], device=exp.device)
-    torch.ldexp(unity, exp, out=scale)
+    if IS_HIP_EXTENSION:
+        host_scale = torch.ldexp(unity.cpu(), exp.cpu())
+        scale = host_scale.to(exp.device)
+    else:
+        torch.ldexp(unity, exp, out=scale)
     # Case where amax is inf. The frexp, ldexp logic changes 0.0 scales
     # Return 0.0 for 0.0 scale for consistency with non-pow2 scale
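On ROCm builds the reference now materializes the scale with torch.ldexp on the CPU and copies it back to the original device instead of using the in-place out= form. A minimal sketch of the same pattern (illustrative only, with made-up exponents):

    import torch

    exp = torch.tensor([-3, 0, 5])                    # per-tile power-of-2 exponents
    unity = torch.tensor([1.0], device=exp.device)
    host_scale = torch.ldexp(unity.cpu(), exp.cpu())  # 1.0 * 2**exp, computed on CPU
    scale = host_scale.to(exp.device)
    print(scale)                                      # tensor([ 0.1250,  1.0000, 32.0000])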
...
...
tests/pytorch/test_float8_blockwise_scaling_exact.py
...
...
@@ -273,7 +273,7 @@ def check_quantization_block_tiling_versus_reference(
     )
     # Check
-    torch.testing.assert_close(qx.float(), qx_ref.float(), atol=0.0 if quant_dtype != torch.int8 else 1.0, rtol=0.0)
+    torch.testing.assert_close(qx.float(), qx_ref.float(), atol=0.0, rtol=0.0)
     # Zero out values that are don't care values
     # Scale format has padding.
     scale_mask = torch.ones(
...
...
@@ -283,7 +283,7 @@ def check_quantization_block_tiling_versus_reference(
         QuantizeResult(qx, scale_mask, None, None), tile_size
     ).scale
     sx = sx * scale_mask
-    torch.testing.assert_close(sx, sx_ref, atol=0.0 if x_dtype != torch.float32 else 1e-5, rtol=0.0 if x_dtype != torch.float32 else 5e-5)
+    torch.testing.assert_close(sx, sx_ref, atol=0.0, rtol=0.0)
     if return_transpose:
         assert qx_t is not None
...
@@ -299,8 +299,8 @@ def check_quantization_block_tiling_versus_reference(
             QuantizeResult(qx_t, scale_mask, None, None), tile_size
         ).scale
         sx_t = sx_t * scale_mask
-        torch.testing.assert_close(qx_t.float(), qx_t_ref.float(), atol=0.0 if quant_dtype != torch.int8 else 1.0, rtol=0.0 if x_dtype != torch.float32 else 2.5e-1)
-        torch.testing.assert_close(sx_t, sx_t_ref, atol=0.0 if x_dtype != torch.float32 else 1e-5, rtol=0.0 if x_dtype != torch.float32 else 5e-5)
+        torch.testing.assert_close(qx_t.float(), qx_t_ref.float(), atol=0.0, rtol=0.0)
+        torch.testing.assert_close(sx_t, sx_t_ref, atol=0.0, rtol=0.0)
     else:
         # should be None
         assert qx_t is None and qx_t_ref is None
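The updated assertions compare exactly (atol=0.0, rtol=0.0) after zeroing the padded scale entries on both sides with a mask. A small illustrative sketch of that masking-then-exact-compare pattern (toy tensors, not from the test):

    import torch

    sx = torch.tensor([[0.5, 0.25, 7.0], [2.0, 1.0, 7.0]])   # last column is padding
    sx_ref = torch.tensor([[0.5, 0.25, 0.0], [2.0, 1.0, 0.0]])
    scale_mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 1.0, 0.0]])
    torch.testing.assert_close(sx * scale_mask, sx_ref, atol=0.0, rtol=0.0)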
...
...
transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu
...
...
@@ -187,8 +187,15 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
       // Step 3: Store cast output
       CType scale_data = block_tile_scale;
-      OType scaled_elt = static_cast<OType>(static_cast<CType>(thrd_tile_input[i].data.elt[j]) * scale_data);
+      OType scaled_elt = 0;
+      if constexpr (std::is_same_v<OType, int8_t>) {
+        scaled_elt = static_cast<OType>(
+            lroundf(fmaxf(-127.0f, fminf(127.0f, static_cast<CType>(thrd_tile_input[i].data.elt[j]) * scale_data))));
+      } else {
+        scaled_elt = static_cast<OType>(static_cast<CType>(thrd_tile_input[i].data.elt[j]) * scale_data);
+      }
       tmp_output_c.data.elt[j] = scaled_elt;

       // Step 4: do transpose within thread tile
       if constexpr (kReturnTranspose) {
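For int8 outputs the kernel clamps the scaled value to the symmetric range ±127 and rounds half away from zero via lroundf. The same arithmetic restated as a standalone sketch (illustrative only, not part of the commit):

    import math

    def quantize_int8(x: float, scale: float) -> int:
        # Mirrors: lroundf(fmaxf(-127.0f, fminf(127.0f, x * scale)))
        v = max(-127.0, min(127.0, x * scale))
        return int(math.floor(v + 0.5)) if v >= 0 else int(math.ceil(v - 0.5))

    print(quantize_int8(3.5, 1.0))     # 4, half rounds away from zero
    print(quantize_int8(-200.0, 1.0))  # -127, clamped to the int8 symmetric range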
...
...
transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu
...
...
@@ -27,90 +27,6 @@
#include "common/utils.cuh"
namespace
transformer_engine
{
#ifdef __HIP_PLATFORM_AMD__
__device__
bool
is_little_endian
()
{
int
num
=
1
;
const
char
*
ptr
=
reinterpret_cast
<
const
char
*>
(
&
num
);
if
(
*
ptr
==
1
)
{
return
true
;
}
else
{
return
false
;
}
}
struct
BitFloat
{
private:
char
data
[
3
];
public:
__device__
BitFloat
(
const
float
val
,
bool
pow2scale
)
{
uint32_t
raw_val
=
*
reinterpret_cast
<
const
uint32_t
*>
(
&
val
);
if
(
~
raw_val
&
0x7f800000
)
{
if
(
pow2scale
&&
(
raw_val
&
0x000000FF
))
{
raw_val
|=
0x100
;
}
else
{
raw_val
+=
0x7f
+
((
raw_val
>>
8
)
&
1
);
}
}
else
if
(
raw_val
&
0xffff
)
{
raw_val
|=
0x100
;
}
raw_val
=
(
raw_val
>>
8
);
const
char
*
ptr
=
reinterpret_cast
<
const
char
*>
(
&
raw_val
);
if
(
is_little_endian
())
{
data
[
0
]
=
ptr
[
0
];
data
[
1
]
=
ptr
[
1
];
data
[
2
]
=
ptr
[
2
];
}
else
{
data
[
0
]
=
ptr
[
1
];
data
[
1
]
=
ptr
[
2
];
data
[
2
]
=
ptr
[
3
];
}
}
__device__
operator
float
()
const
{
uint32_t
raw_val
=
0
;
char
*
ptr
=
reinterpret_cast
<
char
*>
(
&
raw_val
);
if
(
is_little_endian
())
{
ptr
[
1
]
=
data
[
0
];
ptr
[
2
]
=
data
[
1
];
ptr
[
3
]
=
data
[
2
];
}
else
{
ptr
[
0
]
=
data
[
0
];
ptr
[
1
]
=
data
[
1
];
ptr
[
2
]
=
data
[
2
];
}
return
*
reinterpret_cast
<
const
float
*>
(
&
raw_val
);
}
};
struct
BitFloat2
{
BitFloat
u
;
BitFloat
v
;
};
template
<
>
struct
BytesToType
<
6
>
{
using
Type
=
BitFloat2
;
static_assert
(
sizeof
(
Type
)
==
6
);
};
#endif
namespace
{
using
transformer_engine
::
detail
::
FP8BlockwiseColumnwiseOption
;
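The removed BitFloat helper stored an fp32 value in 3 bytes of shared memory by dropping the low mantissa byte: finite values are rounded to nearest even on the dropped byte, except that with power-of-2 scaling a nonzero low byte is folded into a sticky bit rather than rounded, and a nonzero low payload of inf/NaN values is kept so a NaN stays a NaN. The bit manipulation can be sketched in Python (illustrative only; the endianness handling of the original is omitted):

    import struct

    def pack24(val: float, pow2scale: bool) -> int:
        raw = struct.unpack('<I', struct.pack('<f', val))[0]
        if ~raw & 0x7f800000:                    # finite: exponent bits not all ones
            if pow2scale and (raw & 0x000000FF):
                raw |= 0x100                     # sticky bit instead of rounding
            else:
                raw += 0x7F + ((raw >> 8) & 1)   # round dropped byte to nearest even
        elif raw & 0xFFFF:
            raw |= 0x100                         # keep a NaN payload visible in the top bits
        return raw >> 8                          # keep the top 24 bits

    def unpack24(bits: int) -> float:
        return struct.unpack('<f', struct.pack('<I', bits << 8))[0]

    x = 1.0000123
    print(unpack24(pack24(x, pow2scale=False)))  # close to x, low mantissa byte rounded away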
...
...
@@ -278,12 +194,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
   };

   extern __shared__ char smem_base[];
-#ifdef __HIP_PLATFORM_AMD__
-  using HipSMemVec = Vec<std::conditional_t<std::is_same_v<IType, float>, BitFloat, IType>, kNVecSMem>;
-  HipSMemVec* smem = reinterpret_cast<HipSMemVec*>(&smem_base[0]);
-#else
   SMemVec* smem = reinterpret_cast<SMemVec*>(&smem_base[0]);
-#endif

   // Step 1: Load input to shared memory
   {
...
...
@@ -317,23 +228,8 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
     for (int i = 0; i < kNVecIn / kNVecSMem; ++i) {
       int c = c_s + i;
       int r = r_s;
-#ifdef __HIP_PLATFORM_AMD__
-      if constexpr (std::is_same_v<IType, float>) {
-#pragma unroll
-        for (int j = 0; j < kNVecSMem; ++j) {
-          smem[r * kSMemCol + c].data.elt[j] =
-              BitFloat(input_vec.smem_type.data.elt[i].data.elt[j], pow_2_scaling);
-        }
-      } else {
-        smem[r * kSMemCol + c] = input_vec.smem_type.data.elt[i];
-      }
-#else
       smem[r * kSMemCol + c] = input_vec.smem_type.data.elt[i];
-#endif
     }
     // Step 1.3: Update input address, row index of shared memory, (and row index of global memory for not aligned case)
     input_g += stride_g;
     r_s += r_stride;
...
...
@@ -374,22 +270,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
     for (int i = 0; i < kNVecOut / kNVecSMem; ++i) {
       int c = c_s + i;
       int r = r_s;
-#ifdef __HIP_PLATFORM_AMD__
-      if constexpr (std::is_same_v<IType, float>) {
-#pragma unroll
-        for (int j = 0; j < kNVecSMem; ++j) {
-          smem_vec[i].data.elt[j] = smem[r * kSMemCol + c].data.elt[j];
-        }
-      } else {
-        smem_vec[i] = smem[r * kSMemCol + c];
-      }
-#else
       smem_vec[i] = smem[r * kSMemCol + c];
-#endif
     }
     // Step 2.2: Compute local amax
     CType amax = 0;
...
...
@@ -405,7 +286,8 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
#pragma unroll
    for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) {
#ifdef __HIP_PLATFORM_AMD__
      const float other_amax = __shfl_down_sync((unsigned long long)(mask), amax, delta, kThreadsPerWarp);
#else
      const float other_amax = __shfl_down_sync(mask, amax, delta);
#endif
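The reduction halves delta on every iteration, so after log2(kNumThreadsStore) steps the first lane of each group holds the group's maximum, which __shfl_sync then broadcasts. The shape of that reduction, restated as a plain Python sketch (illustrative only):

    def group_reduce_max(vals):
        # vals holds one local amax per lane; the group size is a power of two
        vals = list(vals)
        delta = len(vals) // 2
        while delta > 0:
            for lane in range(delta):
                # each lane folds in the value held delta lanes away (the shfl_down step)
                vals[lane] = max(vals[lane], vals[lane + delta])
            delta //= 2
        return vals[0]

    print(group_reduce_max([0.5, 3.0, 1.25, 2.0]))  # 3.0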
...
...
@@ -438,11 +320,10 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
    for (int i = 0; i < kNVecOut / kNVecSMem; ++i) {
#pragma unroll
      for (int j = 0; j < kNVecSMem; ++j) {
        if constexpr (std::is_same_v<OType, int8_t>) {
          output_vec.data.elt[i * kNVecSMem + j] = static_cast<OType>(
              lroundf(fmaxf(-127.0f, fminf(127.0f, static_cast<CType>(smem_vec[i].data.elt[j]) * scale))));
        } else {
          output_vec.data.elt[i * kNVecSMem + j] =
              static_cast<OType>(static_cast<CType>(smem_vec[i].data.elt[j]) * scale);
        }
...
...
@@ -494,22 +375,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
     for (int i = 0; i < kNVecOut; ++i) {
       int r = r_s + i;
       int c = c_s;
-#ifdef __HIP_PLATFORM_AMD__
-      if constexpr (std::is_same_v<IType, float>) {
-#pragma unroll
-        for (int j = 0; j < kNVecSMem; ++j) {
-          smem_vec[i].data.elt[j] = smem[r * kSMemCol + c].data.elt[j];
-        }
-      } else {
-        smem_vec[i] = smem[r * kSMemCol + c];
-      }
-#else
       smem_vec[i] = smem[r * kSMemCol + c];
-#endif
     }
 #pragma unroll
     for (int smem_idx = 0; smem_idx < kNVecSMem; ++smem_idx) {
...
...
@@ -523,7 +389,8 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
#pragma unroll
    for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) {
#ifdef __HIP_PLATFORM_AMD__
      const float other_amax = __shfl_down_sync((unsigned long long)(mask), amax, delta, kThreadsPerWarp);
#else
      const float other_amax = __shfl_down_sync(mask, amax, delta);
#endif
...
...
@@ -554,11 +421,11 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
      OVec output_vec;
#pragma unroll
      for (int i = 0; i < kNVecOut; ++i) {
        if constexpr (std::is_same_v<OType, int8_t>) {
          output_vec.data.elt[i] = static_cast<OType>(
              lroundf(fmaxf(-127.0f, fminf(127.0f, static_cast<CType>(smem_vec[i].data.elt[smem_idx]) * scale))));
        } else {
          output_vec.data.elt[i] =
              static_cast<OType>(static_cast<CType>(smem_vec[i].data.elt[smem_idx]) * scale);
        }
...
...
@@ -679,6 +546,288 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
  }
}

#ifdef __HIP_PLATFORM_AMD__
constexpr int kFP32SMemCol = kTileDim / kNVecSMem;
constexpr int kFP32SMemSize = kSMemRow * kFP32SMemCol * kNVecSMem;

template <bool kAligned, typename CType, typename IType, typename OType>
__global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpose_kernel_fp32(
    const IType* const input, OType* const output_c, OType* const output_t, CType* const tile_scales_inv_c,
    CType* const tile_scales_inv_t, const size_t row_length, const size_t num_rows, const size_t scale_stride_x,
    const size_t scale_stride_y, const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon,
    FP8BlockwiseRowwiseOption rowwise_option, FP8BlockwiseColumnwiseOption columnwise_option,
    const bool pow_2_scaling) {
  bool return_rowwise = rowwise_option == FP8BlockwiseRowwiseOption::ROWWISE;
  bool return_columnwise_transpose = columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_TRANSPOSE;
  using SMemVec = Vec<IType, kNVecSMem>;
  using OVec = Vec<OType, kNVecOut>;
  union IVec {
    Vec<IType, kNVecIn> input_type;
    Vec<SMemVec, kNVecIn / kNVecSMem> smem_type;
  };

  extern __shared__ char smem_base[];
  SMemVec* smem = reinterpret_cast<SMemVec*>(smem_base);

  // Step 1: Load input to shared memory
  {
    constexpr int r_stride = kThreadsPerBlock / kNumThreadsLoad;  // stride in rows of shared memory
    constexpr int num_iterations = kTileDim / r_stride;
    const int c_s = (threadIdx.x % kNumThreadsLoad) * (kNVecIn / kNVecSMem);          // Column in shared memory
    int r_s = threadIdx.x / kNumThreadsLoad;                                          // Row in shared memory
    const size_t c_g = static_cast<size_t>(blockIdx.x) * kTileDim + c_s * kNVecSMem;  // Column in global memory
    size_t r_g = static_cast<size_t>(blockIdx.y) * kTileDim + r_s;                    // Row in global memory
    const size_t stride_g = static_cast<size_t>(r_stride) * row_length;               // Stride in global memory
    const size_t num_ele =
        c_g < row_length ? min(static_cast<size_t>(kNVecIn), row_length - c_g) : 0;   // For not aligned case
    const IType* input_g = &input[r_g * row_length + c_g];                            // Input address in global memory
#pragma unroll
    for (int iter = 0; iter < num_iterations; ++iter) {
      IVec input_vec;
      // Step 1.1: Load from global memory (input) to registers
      if constexpr (kAligned) {
        input_vec.input_type.load_from(input_g);
      } else {
        if (r_g < num_rows) {
          input_vec.input_type.load_from_elts(input_g, 0, num_ele);
        } else {
          input_vec.input_type.clear();
        }
      }
      // Step 1.2: Write to shared memory - Column Major
#pragma unroll
      for (int i = 0; i < kNVecIn / kNVecSMem; ++i) {
        int c = c_s + i;
        int r = r_s;
        // Column Major Store
        smem[c * kTileDim + r] = input_vec.smem_type.data.elt[i];
      }
      // Step 1.3: Update input address, row index of shared memory, (and row index of global memory for not aligned case)
      input_g += stride_g;
      r_s += r_stride;
      if constexpr (!kAligned) {
        r_g += r_stride;
      }
    }
  }

  __syncthreads();

  // Step 2: Cast and store to output_c
  if (return_rowwise) {
    constexpr int r_stride = kThreadsPerBlock / kNumThreadsStore;  // stride in rows of shared memory
    constexpr int num_iterations = kTileDim / r_stride;
    const int c_s = (threadIdx.x % kNumThreadsStore) * (kNVecOut / kNVecSMem);        // Column in shared memory
    int r_s = threadIdx.x / kNumThreadsStore;                                         // Row in shared memory
    const size_t c_g = static_cast<size_t>(blockIdx.x) * kTileDim + c_s * kNVecSMem;  // Column in global memory
    size_t r_g = static_cast<size_t>(blockIdx.y) * kTileDim + r_s;                    // Row in global memory
    const size_t stride_g = static_cast<size_t>(r_stride) * row_length;               // Stride in global memory
    const size_t num_ele =
        c_g < row_length ? min(static_cast<size_t>(kNVecOut), row_length - c_g) : 0;  // For not aligned case
    OType* output_g = &output_c[r_g * row_length + c_g];                              // Output address in global memory
    // Each kNumThreadsStore threads form a warp process one row, we need to find the lane id of
    // the first thread to do the reduction.
    const unsigned src_lane = (threadIdx.x % kThreadsPerWarp) / kNumThreadsStore * kNumThreadsStore;
    // This mask represents which threads should do the reduction together.
    const unsigned mask = ((1 << kNumThreadsStore) - 1) << src_lane;
    const bool is_src_lane = (threadIdx.x % kNumThreadsStore) == 0;
#pragma unroll
    for (int iter = 0; iter < num_iterations; ++iter) {
      SMemVec smem_vec[kNVecOut / kNVecSMem];
      // Step 2.1: Load from shared memory to registers - Column Major
#pragma unroll
      for (int i = 0; i < kNVecOut / kNVecSMem; ++i) {
        int c = c_s + i;
        int r = r_s;
        // Column Major Read
        smem_vec[i] = smem[c * kTileDim + r];
      }
      // Step 2.2: Compute local amax
      CType amax = 0;
#pragma unroll
      for (int i = 0; i < kNVecOut / kNVecSMem; ++i) {
#pragma unroll
        for (int j = 0; j < kNVecSMem; ++j) {
          __builtin_assume(amax >= 0);
          amax = fmaxf(amax, fabsf(smem_vec[i].data.elt[j]));
        }
      }
      // Step 2.3: Reduce amax
#pragma unroll
      for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) {
#ifdef __HIP_PLATFORM_AMD__
        const float other_amax = __shfl_down_sync((unsigned long long)(mask), amax, delta, kThreadsPerWarp);
#else
        const float other_amax = __shfl_down_sync(mask, amax, delta);
#endif
        __builtin_assume(amax >= 0);
        __builtin_assume(other_amax >= 0);
        amax = fmaxf(amax, other_amax);
      }
#ifdef __HIP_PLATFORM_AMD__
      amax = __shfl_sync((unsigned long long)(mask), amax, src_lane, kThreadsPerWarp);
#else
      amax = __shfl_sync(mask, amax, src_lane);
#endif
      // Step 2.4: Compute scale
      CType scale;
      scale = compute_scale_from_types<IType, OType>(amax, epsilon, pow_2_scaling);
      // Step 2.5: Write scale_inv
      bool write_scale_inv = is_src_lane;
      if constexpr (!kAligned) {
        write_scale_inv &= (r_g < num_rows);
      }
      if (write_scale_inv) {
        CType scale_inv = 1.0 / scale;
        size_t row_idx = static_cast<size_t>(blockIdx.y) * kTileDim + r_s;
        size_t col_idx = static_cast<size_t>(blockIdx.x);
        tile_scales_inv_c[row_idx * scale_stride_y + col_idx * scale_stride_x] = scale_inv;
      }
      // Step 2.6: Quantize
      OVec output_vec;
#pragma unroll
      for (int i = 0; i < kNVecOut / kNVecSMem; ++i) {
#pragma unroll
        for (int j = 0; j < kNVecSMem; ++j) {
          if constexpr (std::is_same_v<OType, int8_t>) {
            output_vec.data.elt[i * kNVecSMem + j] = static_cast<OType>(
                lroundf(fmaxf(-127.0f, fminf(127.0f, static_cast<CType>(smem_vec[i].data.elt[j]) * scale))));
          } else {
            output_vec.data.elt[i * kNVecSMem + j] =
                static_cast<OType>(static_cast<CType>(smem_vec[i].data.elt[j]) * scale);
          }
        }
      }
      // Step 2.7: Store output_c
      if constexpr (kAligned) {
        output_vec.store_to(output_g);
      } else {
        if (r_g < num_rows) {
          output_vec.store_to_elts(output_g, 0, num_ele);
        }
      }
      // Step 2.8: Update output address, row index of shared memory (and row index of global memory for not aligned case)
      output_g += stride_g;
      r_s += r_stride;
      if constexpr (!kAligned) {
        r_g += r_stride;
      }
    }
  }

  // Step 3: Transpose, cast and store to output_t
  if (return_columnwise_transpose) {
    constexpr int c_stride = kThreadsPerBlock / kNumThreadsStore;  // Stride in columns of shared memory
    constexpr int num_iterations = kTileDim / (c_stride * kNVecSMem);
    const int r_s = (threadIdx.x % kNumThreadsStore) * kNVecOut;                      // Row in shared memory
    int c_s = threadIdx.x / kNumThreadsStore;                                         // Column in shared memory
    size_t r_g = static_cast<size_t>(blockIdx.x) * kTileDim + c_s * kNVecSMem;        // Row in global memory
    const size_t c_g = static_cast<size_t>(blockIdx.y) * kTileDim + r_s;              // Column in global memory
    const size_t stride_g = static_cast<size_t>(c_stride) * kNVecSMem * num_rows;     // Stride in global memory
    const size_t num_ele =
        c_g < num_rows ? min(static_cast<size_t>(kNVecOut), num_rows - c_g) : 0;      // For not aligned case
    OType* output_g = &output_t[r_g * num_rows + c_g];                                // Output address in global memory
    // Each kNumThreadsStore threads form a warp process one row, we need to find the lane id of
    // the first thread to do the reduction.
    const unsigned src_lane = (threadIdx.x % kThreadsPerWarp) / kNumThreadsStore * kNumThreadsStore;
    // This mask represents which threads should do the reduction together.
    const unsigned mask = ((1 << kNumThreadsStore) - 1) << src_lane;
    const bool is_src_lane = (threadIdx.x % kNumThreadsStore) == 0;
#pragma unroll
    for (int iter = 0; iter < num_iterations; ++iter) {
      SMemVec smem_vec[kNVecOut];
      // Step 3.1: Load from shared memory to registers - Column Major
#pragma unroll
      for (int i = 0; i < kNVecOut; ++i) {
        int r = r_s + i;
        int c = c_s;
        // Column Major Read
        smem_vec[i] = smem[c * kTileDim + r];
      }
#pragma unroll
      for (int smem_idx = 0; smem_idx < kNVecSMem; ++smem_idx) {
        // Step 3.2: Compute local amax
        CType amax = 0;
#pragma unroll
        for (int i = 0; i < kNVecOut; ++i) {
          amax = fmaxf(amax, fabsf(smem_vec[i].data.elt[smem_idx]));
        }
        // Step 3.3: Reduce amax
#pragma unroll
        for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) {
#ifdef __HIP_PLATFORM_AMD__
          const float other_amax = __shfl_down_sync((unsigned long long)(mask), amax, delta, kThreadsPerWarp);
#else
          const float other_amax = __shfl_down_sync(mask, amax, delta);
#endif
          __builtin_assume(amax >= 0);
          __builtin_assume(other_amax >= 0);
          amax = fmaxf(amax, other_amax);
        }
#ifdef __HIP_PLATFORM_AMD__
        amax = __shfl_sync((unsigned long long)(mask), amax, src_lane, kThreadsPerWarp);
#else
        amax = __shfl_sync(mask, amax, src_lane);
#endif
        // Step 3.4: Compute scale
        CType scale;
        scale = compute_scale_from_types<IType, OType>(amax, epsilon, pow_2_scaling);
        // Step 3.5: Write scale_inv_t
        bool write_scale_inv = is_src_lane;
        if constexpr (!kAligned) {
          write_scale_inv &= (r_g + smem_idx < row_length);
        }
        if (write_scale_inv) {
          CType scale_inv = 1.0 / scale;
          size_t row_idx = static_cast<size_t>(blockIdx.x) * kTileDim + c_s * kNVecSMem + smem_idx;
          size_t col_idx = static_cast<size_t>(blockIdx.y);
          tile_scales_inv_t[row_idx * scale_t_stride_y + col_idx * scale_t_stride_x] = scale_inv;
        }
        // Step 3.6: Quantize
        OVec output_vec;
#pragma unroll
        for (int i = 0; i < kNVecOut; ++i) {
          if constexpr (std::is_same_v<OType, int8_t>) {
            output_vec.data.elt[i] = static_cast<OType>(
                lroundf(fmaxf(-127.0f, fminf(127.0f, static_cast<CType>(smem_vec[i].data.elt[smem_idx]) * scale))));
          } else {
            output_vec.data.elt[i] =
                static_cast<OType>(static_cast<CType>(smem_vec[i].data.elt[smem_idx]) * scale);
          }
        }
        // Step 3.7: Store output_t
        if constexpr (kAligned) {
          output_vec.store_to(output_g + smem_idx * num_rows);
        } else {
          if (r_g + smem_idx < row_length) {
            output_vec.store_to_elts(output_g + smem_idx * num_rows, 0, num_ele);
          }
        }
      }
      // Step 3.8: Update output address, column index of shared memory (and row index of global memory for not aligned case)
      output_g += stride_g;
      c_s += c_stride;
      if constexpr (!kAligned) {
        r_g += c_stride * kNVecSMem;
      }
    }
  }
}
#endif

}  // namespace
}  // namespace transformer_engine
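The new fp32 kernel keeps full-precision inputs in shared memory and stores each tile in column-major order (smem[c * kTileDim + r]), so the row-wise pass and the transposed pass both read the same buffer by swapping which index they walk. A toy Python sketch of that indexing (illustrative only, small tile):

    kTileDim = 4
    tile = [[r * kTileDim + c for c in range(kTileDim)] for r in range(kTileDim)]

    # Column-major store of the tile into a flat "shared memory" list
    smem = [0] * (kTileDim * kTileDim)
    for r in range(kTileDim):
        for c in range(kTileDim):
            smem[c * kTileDim + r] = tile[r][c]

    row0 = [smem[c * kTileDim + 0] for c in range(kTileDim)]  # walk row 0: [0, 1, 2, 3]
    col0 = [smem[0 * kTileDim + r] for r in range(kTileDim)]  # walk column 0: [0, 4, 8, 12]
    print(row0, col0)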
...
...
@@ -767,23 +916,49 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
   TRANSFORMER_ENGINE_SWITCH_CONDITION(
       full_tile, kAligned,
-#ifdef __HIP_PLATFORM_AMD__
-      using HipSMemType = std::conditional_t<std::is_same_v<InputType, float>, BitFloat, InputType>;
-      size_t smem_bytes = kSMemSize * sizeof(HipSMemType);
-#else
+#ifdef __HIP_PLATFORM_AMD__
+      if constexpr (std::is_same_v<InputType, float>) {
+        size_t smem_bytes = kFP32SMemSize * sizeof(InputType);
+        if (smem_bytes >= 48 * 1024) {
+          cudaError_t err = cudaFuncSetAttribute(
+              (const void*)&block_scaled_1d_cast_transpose_kernel_fp32<kAligned, float, InputType, OutputType>,
+              cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
+        }
+        block_scaled_1d_cast_transpose_kernel_fp32<kAligned, float, InputType, OutputType>
+            <<<grid, kThreadsPerBlock, smem_bytes, stream>>>(
+                reinterpret_cast<const InputType*>(input.dptr), reinterpret_cast<OutputType*>(output.dptr),
+                reinterpret_cast<OutputType*>(output_t.dptr), reinterpret_cast<float*>(scale_inv.dptr),
+                reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows, scale_stride_x, scale_stride_y,
+                scale_t_stride_x, scale_t_stride_y, epsilon, rowwise_option, columnwise_option, pow2_scale);
+      } else {
+        size_t smem_bytes = kSMemSize * sizeof(InputType);
+        // shared memory must be requested up
+        if (smem_bytes >= 48 * 1024) {
+          cudaError_t err = cudaFuncSetAttribute(
+              (const void*)&block_scaled_1d_cast_transpose_kernel<kAligned, float, InputType, OutputType>,
+              cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
+        }
+        block_scaled_1d_cast_transpose_kernel<kAligned, float, InputType, OutputType>
+            <<<grid, kThreadsPerBlock, smem_bytes, stream>>>(
+                reinterpret_cast<const InputType*>(input.dptr), reinterpret_cast<OutputType*>(output.dptr),
+                reinterpret_cast<OutputType*>(output_t.dptr), reinterpret_cast<float*>(scale_inv.dptr),
+                reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows, scale_stride_x, scale_stride_y,
+                scale_t_stride_x, scale_t_stride_y, epsilon, rowwise_option, columnwise_option, pow2_scale);
+      }
+#else
       size_t smem_bytes = kSMemSize * sizeof(InputType);
-#endif
       // shared memory must be requested up
       if (smem_bytes >= 48 * 1024) {
-#ifdef __HIP_PLATFORM_AMD__
-        cudaError_t err = cudaFuncSetAttribute((const void*)&block_scaled_1d_cast_transpose_kernel<kAligned, float, InputType, OutputType>,
-                                               cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
-#else
         cudaError_t err = cudaFuncSetAttribute(&block_scaled_1d_cast_transpose_kernel<kAligned, float, InputType, OutputType>,
                                                cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
-#endif
         NVTE_CHECK(err == cudaSuccess, "Failed to set dynamic shared memory size.");
       }
       block_scaled_1d_cast_transpose_kernel<kAligned, float, InputType, OutputType>
           <<<grid, kThreadsPerBlock, smem_bytes, stream>>>(
...
...
@@ -793,7 +968,9 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
               reinterpret_cast<float*>(scale_inv.dptr), reinterpret_cast<float*>(scale_inv_t.dptr),
               row_length, num_rows, scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y,
               epsilon, rowwise_option,
-              columnwise_option, pow2_scale);)  // kAligned
+              columnwise_option, pow2_scale);
+#endif
+      )  // kAligned
   )  // OutputType
   )  // InputType

   NVTE_CHECK_CUDA(cudaGetLastError());
...
...