OpenDAS / TransformerEngine · Commits · 1f9c104b

Commit 1f9c104b, authored Jun 18, 2025 by wenjh
Merge branch 'develop_v2.4'
Parents: 2b1428ff, 8a03ff34

Showing 6 changed files with 369 additions and 207 deletions (+369, -207)
tests/cpp/operator/test_cast_float8blockwise.cu                              +3    -33
tests/pytorch/references/blockwise_quantizer_reference.py                   +5    -1
tests/pytorch/references/quantize_scale_calc.py                             +6    -2
tests/pytorch/test_float8_blockwise_scaling_exact.py                        +4    -4
transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu  +9    -2
transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu  +342  -165

tests/cpp/operator/test_cast_float8blockwise.cu
@@ -263,30 +263,16 @@ void compare_scaling_factors(const std::string& name, const float* test, const f
 void compare_scaling_factors_one_dimensional_blocks(const std::string& name, const float* test,
                                                     const float* ref, const size_t rows,
-                                                    const size_t col_blocks
-#ifdef __HIP_PLATFORM_AMD__
-                                                    , double atol = 0., double rtol = 0.
-#endif
-) {
+                                                    const size_t col_blocks) {
   const size_t test_stride = scale_align_stride(rows);
   for (int i = 0; i < rows; ++i) {
     for (int j = 0; j < col_blocks; ++j) {
       const int test_idx = i + test_stride * j;
       const int ref_idx = i + rows * j;
-#ifdef __HIP_PLATFORM_AMD__
-      double t = static_cast<double>(static_cast<float>(test[test_idx]));
-      double r = static_cast<double>(static_cast<float>(ref[ref_idx]));
-      bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol);
-      ASSERT_FALSE(mismatch) << "Error in " << name << std::endl
-                             << "Mismatch: " << t << " vs " << r
-                             << " at index " << test_idx << "," << ref_idx;
-#else
       ASSERT_FALSE(test[test_idx] != ref[ref_idx])
           << "Error in " << name << std::endl
           << "Mismatch: " << test[test_idx] << " vs " << ref[ref_idx]
          << " at index " << test_idx << "," << ref_idx;
-#endif
     }
   }
 }
@@ -425,33 +411,17 @@ void runTestCaseOneDimensionalBlocks(const ProcessingMethod processing_method,
   float atol = 0.0;
   float rtol = 0.0;
-#ifdef __HIP_PLATFORM_AMD__
-  double atol_scale = 0.0;
-  double rtol_scale = 0.0;
-  if (itype == DType::kFloat32) {
-    atol_scale = 1e-5;
-  }
-#endif
   if (rowwise) {
     compareResults("output_c", output_c, ref_output.get(), true, atol, rtol);
     compare_scaling_factors_one_dimensional_blocks("scale_inv",
                                                    output_c.rowwise_cpu_scale_inv_ptr<float>(),
-                                                   ref_scale_inv.get(), rows, blocks_x
-#ifdef __HIP_PLATFORM_AMD__
-                                                   , atol_scale, rtol_scale
-#endif
-    );
+                                                   ref_scale_inv.get(), rows, blocks_x);
   }
   if (colwise) {
     compareResults("output_c_t", output_c, ref_output_t.get(), false, atol, rtol);
     compare_scaling_factors_one_dimensional_blocks("scale_inv_t",
                                                    output_c.columnwise_cpu_scale_inv_ptr<float>(),
-                                                   ref_scale_inv_t.get(), cols, blocks_x_t
-#ifdef __HIP_PLATFORM_AMD__
-                                                   , atol_scale, rtol_scale
-#endif
-    );
+                                                   ref_scale_inv_t.get(), cols, blocks_x_t);
   }
 }

tests/pytorch/references/blockwise_quantizer_reference.py
@@ -171,7 +171,11 @@ class BlockwiseQuantizerReference:
         qx = x_tiled * scale.reshape(M, K // tile_len, 1)
         qx = torch.clamp(qx, min=-dtype_max, max=dtype_max)
         if quant_dtype == torch.int8:
-            qx = torch.round(qx)
+            positive_mask = qx >= 0
+            negative_mask = ~positive_mask
+            pos_part = torch.where(positive_mask, torch.floor(qx + 0.5), 0)
+            neg_part = torch.where(negative_mask, torch.ceil(qx - 0.5), 0)
+            qx = pos_part + neg_part
         qx = qx.to(dtype=quant_dtype)
         qx = qx.reshape(M, K)
         return qx, scale_inv
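
A note on the change above: torch.round rounds ties to even, while the CUDA/HIP kernels quantize with lroundf, which rounds halves away from zero; the reference now implements the latter explicitly so int8 results can be compared exactly. A minimal host-side C++ sketch of the difference between the two rounding rules (sample values are hypothetical, this is not code from the repository):

#include <cmath>
#include <cstdio>

// Round half away from zero, matching lroundf(), vs. ties-to-even
// (nearbyint under the default rounding mode).
double round_half_away(double x) {
  return x >= 0.0 ? std::floor(x + 0.5) : std::ceil(x - 0.5);
}

int main() {
  const double samples[] = {0.5, 1.5, 2.5, -0.5, -2.5};
  for (double x : samples) {
    // ties-to-even: 2.5 -> 2;  half-away-from-zero: 2.5 -> 3
    std::printf("x=% .1f  ties-to-even=% .0f  half-away=% .0f\n",
                x, std::nearbyint(x), round_half_away(x));
  }
  return 0;
}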

tests/pytorch/references/quantize_scale_calc.py
@@ -4,7 +4,7 @@
 from typing import Tuple

 import torch
+from torch.utils.cpp_extension import IS_HIP_EXTENSION


 def scale_from_amax_tensor(
     x_dtype: torch.dtype,

@@ -48,6 +48,10 @@ def scale_from_amax_tensor(
     # No subnormals and zero.
     assert (exp > -127).all()
     unity = torch.tensor([1.0], device=exp.device)
-    torch.ldexp(unity, exp, out=scale)
+    if IS_HIP_EXTENSION:
+        host_scale = torch.ldexp(unity.cpu(), exp.cpu())
+        scale = host_scale.to(exp.device)
+    else:
+        torch.ldexp(unity, exp, out=scale)
     # Case where amax is inf. The frexp, ldexp logic changes 0.0 scales
     # Return 0.0 for 0.0 scale for consistency with non-pow2 scale
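
For context, scale_from_amax_tensor derives a power-of-two scale from the exponent of amax (the surrounding comments refer to the frexp/ldexp logic); the hunk above only moves the ldexp call to the CPU when running under HIP. A small standalone C++ sketch of the frexp/ldexp power-of-two-scale pattern (the helper name and the E4M3 maximum of 448 are assumptions for illustration, not taken from this file):

#include <cmath>
#include <cstdio>

// Largest power-of-two scale such that amax * scale stays at or below fp8_max.
float pow2_scale_from_amax(float amax, float fp8_max = 448.0f) {
  if (amax == 0.0f) return 1.0f;           // nothing to scale
  int exp = 0;
  std::frexp(fp8_max / amax, &exp);        // fp8_max / amax == m * 2^exp, m in [0.5, 1)
  return std::ldexp(1.0f, exp - 1);        // 2^(exp-1) <= fp8_max / amax
}

int main() {
  const float amaxes[] = {0.75f, 13.0f, 1000.0f};
  for (float amax : amaxes) {
    const float s = pow2_scale_from_amax(amax);
    std::printf("amax=%g  scale=%g  amax*scale=%g\n", amax, s, amax * s);
  }
  return 0;
}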

tests/pytorch/test_float8_blockwise_scaling_exact.py
@@ -273,7 +273,7 @@ def check_quantization_block_tiling_versus_reference(
     )
     # Check
-    torch.testing.assert_close(qx.float(), qx_ref.float(), atol=0.0 if quant_dtype != torch.int8 else 1.0, rtol=0.0)
+    torch.testing.assert_close(qx.float(), qx_ref.float(), atol=0.0, rtol=0.0)
     # Zero out values that are don't care values
     # Scale format has padding.
     scale_mask = torch.ones(

@@ -283,7 +283,7 @@ def check_quantization_block_tiling_versus_reference(
         QuantizeResult(qx, scale_mask, None, None), tile_size
     ).scale
     sx = sx * scale_mask
-    torch.testing.assert_close(sx, sx_ref, atol=0.0 if x_dtype != torch.float32 else 1e-5, rtol=0.0 if x_dtype != torch.float32 else 5e-5)
+    torch.testing.assert_close(sx, sx_ref, atol=0.0, rtol=0.0)
     if return_transpose:
         assert qx_t is not None

@@ -299,8 +299,8 @@ def check_quantization_block_tiling_versus_reference(
             QuantizeResult(qx_t, scale_mask, None, None), tile_size
         ).scale
         sx_t = sx_t * scale_mask
-        torch.testing.assert_close(qx_t.float(), qx_t_ref.float(), atol=0.0 if quant_dtype != torch.int8 else 1.0, rtol=0.0 if x_dtype != torch.float32 else 2.5e-1)
-        torch.testing.assert_close(sx_t, sx_t_ref, atol=0.0 if x_dtype != torch.float32 else 1e-5, rtol=0.0 if x_dtype != torch.float32 else 5e-5)
+        torch.testing.assert_close(qx_t.float(), qx_t_ref.float(), atol=0.0, rtol=0.0)
+        torch.testing.assert_close(sx_t, sx_t_ref, atol=0.0, rtol=0.0)
     else:
         # should be None
        assert qx_t is None and qx_t_ref is None

transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu
@@ -187,8 +187,15 @@
       // Step 3: Store cast output
       CType scale_data = block_tile_scale;
-      OType scaled_elt =
+      OType scaled_elt = 0;
+      if constexpr (std::is_same_v<OType, int8_t>) {
+        scaled_elt = static_cast<OType>(lroundf(fmaxf(
+            -127.0f, fminf(127.0f, static_cast<CType>(thrd_tile_input[i].data.elt[j]) * scale_data))));
+      } else {
+        scaled_elt =
+            static_cast<OType>(static_cast<CType>(thrd_tile_input[i].data.elt[j]) * scale_data);
+      }
       tmp_output_c.data.elt[j] = scaled_elt;

       // Step 4: do transpose within thread tile
       if constexpr (kReturnTranspose) {
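
The new int8 branch scales, clamps to [-127, 127], and rounds half away from zero with lroundf. A host-side C++ sketch of that quantization step in isolation (quantize_int8 and the sample scale are hypothetical, for illustration only):

#include <cmath>
#include <cstdint>
#include <cstdio>

// Mirror of the int8 branch above: scale, clamp to [-127, 127], round half away from zero.
int8_t quantize_int8(float x, float scale) {
  const float scaled = x * scale;
  const float clamped = std::fmax(-127.0f, std::fmin(127.0f, scaled));
  return static_cast<int8_t>(std::lround(clamped));
}

int main() {
  const float scale = 127.0f / 3.0f;  // hypothetical per-block scale for amax = 3.0
  const float xs[] = {-3.0f, -0.01f, 0.01f, 1.5f, 3.0f};
  for (float x : xs) {
    std::printf("x=% .2f -> q=%d\n", x, quantize_int8(x, scale));
  }
  return 0;
}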

transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu
@@ -27,90 +27,6 @@
 #include "common/utils.cuh"

 namespace transformer_engine {

-#ifdef __HIP_PLATFORM_AMD__
-__device__ bool is_little_endian() {
-  int num = 1;
-  const char* ptr = reinterpret_cast<const char*>(&num);
-  if (*ptr == 1) {
-    return true;
-  } else {
-    return false;
-  }
-}
-
-struct BitFloat {
- private:
-  char data[3];
-
- public:
-  __device__ BitFloat(const float val, bool pow2scale) {
-    uint32_t raw_val = *reinterpret_cast<const uint32_t*>(&val);
-    if (~raw_val & 0x7f800000) {
-      if (pow2scale && (raw_val & 0x000000FF)) {
-        raw_val |= 0x100;
-      } else {
-        raw_val += 0x7f + ((raw_val >> 8) & 1);
-      }
-    } else if (raw_val & 0xffff) {
-      raw_val |= 0x100;
-    }
-    raw_val = (raw_val >> 8);
-    const char* ptr = reinterpret_cast<const char*>(&raw_val);
-    if (is_little_endian()) {
-      data[0] = ptr[0];
-      data[1] = ptr[1];
-      data[2] = ptr[2];
-    } else {
-      data[0] = ptr[1];
-      data[1] = ptr[2];
-      data[2] = ptr[3];
-    }
-  }
-
-  __device__ operator float() const {
-    uint32_t raw_val = 0;
-    char* ptr = reinterpret_cast<char*>(&raw_val);
-    if (is_little_endian()) {
-      ptr[1] = data[0];
-      ptr[2] = data[1];
-      ptr[3] = data[2];
-    } else {
-      ptr[0] = data[0];
-      ptr[1] = data[1];
-      ptr[2] = data[2];
-    }
-    return *reinterpret_cast<const float*>(&raw_val);
-  }
-};
-
-struct BitFloat2 {
-  BitFloat u;
-  BitFloat v;
-};
-
-template <>
-struct BytesToType<6> {
-  using Type = BitFloat2;
-  static_assert(sizeof(Type) == 6);
-};
-#endif
-
 namespace {
 using transformer_engine::detail::FP8BlockwiseColumnwiseOption;
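
For background on the code removed above: BitFloat stored only the upper three bytes of an fp32 bit pattern, rounding away the low mantissa byte, so an fp32 tile could be staged in shared memory at 3 bytes per element. A host-side C++ sketch of the same pack/unpack idea, without the endianness handling and the pow2scale special case (pack24/unpack24 are hypothetical names, not the repository's API):

#include <cstdint>
#include <cstdio>
#include <cstring>

// Keep the top 3 bytes of a float's bit pattern, rounding the discarded low
// mantissa byte to nearest (ties to even), then expand back to a float.
uint32_t pack24(float val) {
  uint32_t raw;
  std::memcpy(&raw, &val, sizeof(raw));
  if (~raw & 0x7f800000u) {                 // finite: round the low byte away
    raw += 0x7fu + ((raw >> 8) & 1u);
  } else if (raw & 0xffffu) {               // NaN: keep it a NaN after truncation
    raw |= 0x100u;
  }
  return raw >> 8;                          // 24-bit payload
}

float unpack24(uint32_t packed) {
  const uint32_t raw = packed << 8;         // low byte becomes zero
  float val;
  std::memcpy(&val, &raw, sizeof(val));
  return val;
}

int main() {
  const float xs[] = {1.0f, 3.14159265f, 1e-3f};
  for (float x : xs) {
    std::printf("x=%.9g  roundtrip=%.9g\n", x, unpack24(pack24(x)));
  }
  return 0;
}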
@@ -278,12 +194,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
   };
   extern __shared__ char smem_base[];
-#ifdef __HIP_PLATFORM_AMD__
-  using HipSMemVec =
-      Vec<std::conditional_t<std::is_same_v<IType, float>, BitFloat, IType>, kNVecSMem>;
-  HipSMemVec* smem = reinterpret_cast<HipSMemVec*>(&smem_base[0]);
-#else
   SMemVec* smem = reinterpret_cast<SMemVec*>(&smem_base[0]);
-#endif

   // Step 1: Load input to shared memory
   {
@@ -317,23 +228,8 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
     for (int i = 0; i < kNVecIn / kNVecSMem; ++i) {
       int c = c_s + i;
       int r = r_s;
-#ifdef __HIP_PLATFORM_AMD__
-      if constexpr (std::is_same_v<IType, float>) {
-#pragma unroll
-        for (int j = 0; j < kNVecSMem; ++j) {
-          smem[r * kSMemCol + c].data.elt[j] =
-              BitFloat(input_vec.smem_type.data.elt[i].data.elt[j], pow_2_scaling);
-        }
-      } else {
-        smem[r * kSMemCol + c] = input_vec.smem_type.data.elt[i];
-      }
-#else
       smem[r * kSMemCol + c] = input_vec.smem_type.data.elt[i];
-#endif
     }
     // Step 1.3: Update input address, row index of shared memory, (and row index of global memory for not aligned case)
     input_g += stride_g;
     r_s += r_stride;
@@ -374,22 +270,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
     for (int i = 0; i < kNVecOut / kNVecSMem; ++i) {
       int c = c_s + i;
       int r = r_s;
-#ifdef __HIP_PLATFORM_AMD__
-      if constexpr (std::is_same_v<IType, float>) {
-#pragma unroll
-        for (int j = 0; j < kNVecSMem; ++j) {
-          smem_vec[i].data.elt[j] = smem[r * kSMemCol + c].data.elt[j];
-        }
-      } else {
-        smem_vec[i] = smem[r * kSMemCol + c];
-      }
-#else
       smem_vec[i] = smem[r * kSMemCol + c];
-#endif
     }
     // Step 2.2: Compute local amax
     CType amax = 0;
@@ -405,7 +286,8 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
 #pragma unroll
     for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) {
 #ifdef __HIP_PLATFORM_AMD__
       const float other_amax =
           __shfl_down_sync((unsigned long long)(mask), amax, delta, kThreadsPerWarp);
 #else
       const float other_amax = __shfl_down_sync(mask, amax, delta);
 #endif
@@ -438,11 +320,10 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
     for (int i = 0; i < kNVecOut / kNVecSMem; ++i) {
 #pragma unroll
       for (int j = 0; j < kNVecSMem; ++j) {
         if constexpr (std::is_same_v<OType, int8_t>) {
           output_vec.data.elt[i * kNVecSMem + j] = static_cast<OType>(lroundf(
               fmaxf(-127.0f, fminf(127.0f, static_cast<CType>(smem_vec[i].data.elt[j]) * scale))));
         } else {
           output_vec.data.elt[i * kNVecSMem + j] =
               static_cast<OType>(static_cast<CType>(smem_vec[i].data.elt[j]) * scale);
         }
@@ -494,22 +375,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
     for (int i = 0; i < kNVecOut; ++i) {
       int r = r_s + i;
       int c = c_s;
-#ifdef __HIP_PLATFORM_AMD__
-      if constexpr (std::is_same_v<IType, float>) {
-#pragma unroll
-        for (int j = 0; j < kNVecSMem; ++j) {
-          smem_vec[i].data.elt[j] = smem[r * kSMemCol + c].data.elt[j];
-        }
-      } else {
-        smem_vec[i] = smem[r * kSMemCol + c];
-      }
-#else
       smem_vec[i] = smem[r * kSMemCol + c];
-#endif
     }
 #pragma unroll
     for (int smem_idx = 0; smem_idx < kNVecSMem; ++smem_idx) {
@@ -523,7 +389,8 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
 #pragma unroll
       for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) {
 #ifdef __HIP_PLATFORM_AMD__
         const float other_amax =
             __shfl_down_sync((unsigned long long)(mask), amax, delta, kThreadsPerWarp);
 #else
         const float other_amax = __shfl_down_sync(mask, amax, delta);
 #endif
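
Both kernels reduce the per-thread amax across the kNumThreadsStore lanes of a warp with a shuffle-down tree; the HIP branch only widens the mask to unsigned long long and passes the warp width explicitly. A plain C++ sketch of the same tree-reduction pattern, with a vector standing in for the lanes (illustrative only, group size is hypothetical):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  constexpr int kNumThreadsStore = 8;  // stand-in for the kernel's reduction group size
  std::vector<float> amax = {0.1f, -2.5f, 3.0f, 0.0f, 1.2f, -4.0f, 0.7f, 2.2f};
  for (float& v : amax) v = std::fabs(v);  // each "lane" starts with its local amax

  // Shuffle-down tree: at each step, lane i combines with lane i + delta.
  for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) {
    for (int i = 0; i + delta < kNumThreadsStore; ++i) {
      amax[i] = std::max(amax[i], amax[i + delta]);
    }
  }

  // Lane 0 now holds the group amax; the kernel broadcasts it back with __shfl_sync.
  std::printf("group amax = %g\n", amax[0]);
  return 0;
}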
@@ -554,11 +421,11 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
       OVec output_vec;
 #pragma unroll
       for (int i = 0; i < kNVecOut; ++i) {
         if constexpr (std::is_same_v<OType, int8_t>) {
           output_vec.data.elt[i] = static_cast<OType>(lroundf(fmaxf(
               -127.0f, fminf(127.0f, static_cast<CType>(smem_vec[i].data.elt[smem_idx]) * scale))));
         } else {
           output_vec.data.elt[i] =
               static_cast<OType>(static_cast<CType>(smem_vec[i].data.elt[smem_idx]) * scale);
         }
@@ -679,6 +546,288 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
   }
 }

+#ifdef __HIP_PLATFORM_AMD__
+constexpr int kFP32SMemCol = kTileDim / kNVecSMem;
+constexpr int kFP32SMemSize = kSMemRow * kFP32SMemCol * kNVecSMem;
+
+template <bool kAligned, typename CType, typename IType, typename OType>
+__global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpose_kernel_fp32(
+    const IType* const input, OType* const output_c, OType* const output_t,
+    CType* const tile_scales_inv_c, CType* const tile_scales_inv_t, const size_t row_length,
+    const size_t num_rows, const size_t scale_stride_x, const size_t scale_stride_y,
+    const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon,
+    FP8BlockwiseRowwiseOption rowwise_option, FP8BlockwiseColumnwiseOption columnwise_option,
+    const bool pow_2_scaling) {
+  bool return_rowwise = rowwise_option == FP8BlockwiseRowwiseOption::ROWWISE;
+  bool return_columnwise_transpose =
+      columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_TRANSPOSE;
+  using SMemVec = Vec<IType, kNVecSMem>;
+  using OVec = Vec<OType, kNVecOut>;
+  union IVec {
+    Vec<IType, kNVecIn> input_type;
+    Vec<SMemVec, kNVecIn / kNVecSMem> smem_type;
+  };
+  extern __shared__ char smem_base[];
+  SMemVec* smem = reinterpret_cast<SMemVec*>(smem_base);
+
+  // Step 1: Load input to shared memory
+  {
+    constexpr int r_stride = kThreadsPerBlock / kNumThreadsLoad;  // stride in rows of shared memory
+    constexpr int num_iterations = kTileDim / r_stride;
+    const int c_s = (threadIdx.x % kNumThreadsLoad) * (kNVecIn / kNVecSMem);  // Column in shared memory
+    int r_s = threadIdx.x / kNumThreadsLoad;                                  // Row in shared memory
+    const size_t c_g = static_cast<size_t>(blockIdx.x) * kTileDim + c_s * kNVecSMem;  // Column in global memory
+    size_t r_g = static_cast<size_t>(blockIdx.y) * kTileDim + r_s;                    // Row in global memory
+    const size_t stride_g = static_cast<size_t>(r_stride) * row_length;               // Stride in global memory
+    const size_t num_ele =
+        c_g < row_length ? min(static_cast<size_t>(kNVecIn), row_length - c_g) : 0;  // For not aligned case
+    const IType* input_g = &input[r_g * row_length + c_g];  // Input address in global memory
+#pragma unroll
+    for (int iter = 0; iter < num_iterations; ++iter) {
+      IVec input_vec;
+      // Step 1.1: Load from global memory (input) to registers
+      if constexpr (kAligned) {
+        input_vec.input_type.load_from(input_g);
+      } else {
+        if (r_g < num_rows) {
+          input_vec.input_type.load_from_elts(input_g, 0, num_ele);
+        } else {
+          input_vec.input_type.clear();
+        }
+      }
+      // Step 1.2: Write to shared memory - Column Major
+#pragma unroll
+      for (int i = 0; i < kNVecIn / kNVecSMem; ++i) {
+        int c = c_s + i;
+        int r = r_s;
+        // Column Major Store
+        smem[c * kTileDim + r] = input_vec.smem_type.data.elt[i];
+      }
+      // Step 1.3: Update input address, row index of shared memory, (and row index of global memory for not aligned case)
+      input_g += stride_g;
+      r_s += r_stride;
+      if constexpr (!kAligned) {
+        r_g += r_stride;
+      }
+    }
+  }
+
+  __syncthreads();
+
+  // Step 2: Cast and store to output_c
+  if (return_rowwise) {
+    constexpr int r_stride = kThreadsPerBlock / kNumThreadsStore;  // stride in rows of shared memory
+    constexpr int num_iterations = kTileDim / r_stride;
+    const int c_s = (threadIdx.x % kNumThreadsStore) * (kNVecOut / kNVecSMem);  // Column in shared memory
+    int r_s = threadIdx.x / kNumThreadsStore;                                   // Row in shared memory
+    const size_t c_g = static_cast<size_t>(blockIdx.x) * kTileDim + c_s * kNVecSMem;  // Column in global memory
+    size_t r_g = static_cast<size_t>(blockIdx.y) * kTileDim + r_s;                    // Row in global memory
+    const size_t stride_g = static_cast<size_t>(r_stride) * row_length;               // Stride in global memory
+    const size_t num_ele =
+        c_g < row_length ? min(static_cast<size_t>(kNVecOut), row_length - c_g) : 0;  // For not aligned case
+    OType* output_g = &output_c[r_g * row_length + c_g];  // Output address in global memory
+    // Each kNumThreadsStore threads form a warp process one row, we need to find the lane id of
+    // the first thread to do the reduction.
+    const unsigned src_lane = (threadIdx.x % kThreadsPerWarp) / kNumThreadsStore * kNumThreadsStore;
+    // This mask represents which threads should do the reduction together.
+    const unsigned mask = ((1 << kNumThreadsStore) - 1) << src_lane;
+    const bool is_src_lane = (threadIdx.x % kNumThreadsStore) == 0;
+#pragma unroll
+    for (int iter = 0; iter < num_iterations; ++iter) {
+      SMemVec smem_vec[kNVecOut / kNVecSMem];
+      // Step 2.1: Load from shared memory to registers - Column Major
+#pragma unroll
+      for (int i = 0; i < kNVecOut / kNVecSMem; ++i) {
+        int c = c_s + i;
+        int r = r_s;
+        // Column Major Read
+        smem_vec[i] = smem[c * kTileDim + r];
+      }
+      // Step 2.2: Compute local amax
+      CType amax = 0;
+#pragma unroll
+      for (int i = 0; i < kNVecOut / kNVecSMem; ++i) {
+#pragma unroll
+        for (int j = 0; j < kNVecSMem; ++j) {
+          __builtin_assume(amax >= 0);
+          amax = fmaxf(amax, fabsf(smem_vec[i].data.elt[j]));
+        }
+      }
+      // Step 2.3: Reduce amax
+#pragma unroll
+      for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) {
+#ifdef __HIP_PLATFORM_AMD__
+        const float other_amax =
+            __shfl_down_sync((unsigned long long)(mask), amax, delta, kThreadsPerWarp);
+#else
+        const float other_amax = __shfl_down_sync(mask, amax, delta);
+#endif
+        __builtin_assume(amax >= 0);
+        __builtin_assume(other_amax >= 0);
+        amax = fmaxf(amax, other_amax);
+      }
+#ifdef __HIP_PLATFORM_AMD__
+      amax = __shfl_sync((unsigned long long)(mask), amax, src_lane, kThreadsPerWarp);
+#else
+      amax = __shfl_sync(mask, amax, src_lane);
+#endif
+      // Step 2.4: Compute scale
+      CType scale;
+      scale = compute_scale_from_types<IType, OType>(amax, epsilon, pow_2_scaling);
+      // Step 2.5: Write scale_inv
+      bool write_scale_inv = is_src_lane;
+      if constexpr (!kAligned) {
+        write_scale_inv &= (r_g < num_rows);
+      }
+      if (write_scale_inv) {
+        CType scale_inv = 1.0 / scale;
+        size_t row_idx = static_cast<size_t>(blockIdx.y) * kTileDim + r_s;
+        size_t col_idx = static_cast<size_t>(blockIdx.x);
+        tile_scales_inv_c[row_idx * scale_stride_y + col_idx * scale_stride_x] = scale_inv;
+      }
+      // Step 2.6: Quantize
+      OVec output_vec;
+#pragma unroll
+      for (int i = 0; i < kNVecOut / kNVecSMem; ++i) {
+#pragma unroll
+        for (int j = 0; j < kNVecSMem; ++j) {
+          if constexpr (std::is_same_v<OType, int8_t>) {
+            output_vec.data.elt[i * kNVecSMem + j] = static_cast<OType>(lroundf(fmaxf(
+                -127.0f, fminf(127.0f, static_cast<CType>(smem_vec[i].data.elt[j]) * scale))));
+          } else {
+            output_vec.data.elt[i * kNVecSMem + j] =
+                static_cast<OType>(static_cast<CType>(smem_vec[i].data.elt[j]) * scale);
+          }
+        }
+      }
+      // Step 2.7: Store output_c
+      if constexpr (kAligned) {
+        output_vec.store_to(output_g);
+      } else {
+        if (r_g < num_rows) {
+          output_vec.store_to_elts(output_g, 0, num_ele);
+        }
+      }
+      // Step 2.8: Update output address, row index of shared memory (and row index of global memory for not aligned case)
+      output_g += stride_g;
+      r_s += r_stride;
+      if constexpr (!kAligned) {
+        r_g += r_stride;
+      }
+    }
+  }
+
+  // Step 3: Transpose, cast and store to output_t
+  if (return_columnwise_transpose) {
+    constexpr int c_stride = kThreadsPerBlock / kNumThreadsStore;  // Stride in columns of shared memory
+    constexpr int num_iterations = kTileDim / (c_stride * kNVecSMem);
+    const int r_s = (threadIdx.x % kNumThreadsStore) * kNVecOut;  // Row in shared memory
+    int c_s = threadIdx.x / kNumThreadsStore;                     // Column in shared memory
+    size_t r_g = static_cast<size_t>(blockIdx.x) * kTileDim + c_s * kNVecSMem;  // Row in global memory
+    const size_t c_g = static_cast<size_t>(blockIdx.y) * kTileDim + r_s;        // Column in global memory
+    const size_t stride_g = static_cast<size_t>(c_stride) * kNVecSMem * num_rows;  // Stride in global memory
+    const size_t num_ele =
+        c_g < num_rows ? min(static_cast<size_t>(kNVecOut), num_rows - c_g) : 0;  // For not aligned case
+    OType* output_g = &output_t[r_g * num_rows + c_g];  // Output address in global memory
+    // Each kNumThreadsStore threads form a warp process one row, we need to find the lane id of
+    // the first thread to do the reduction.
+    const unsigned src_lane = (threadIdx.x % kThreadsPerWarp) / kNumThreadsStore * kNumThreadsStore;
+    // This mask represents which threads should do the reduction together.
+    const unsigned mask = ((1 << kNumThreadsStore) - 1) << src_lane;
+    const bool is_src_lane = (threadIdx.x % kNumThreadsStore) == 0;
+#pragma unroll
+    for (int iter = 0; iter < num_iterations; ++iter) {
+      SMemVec smem_vec[kNVecOut];
+      // Step 3.1: Load from shared memory to registers - Column Major
+#pragma unroll
+      for (int i = 0; i < kNVecOut; ++i) {
+        int r = r_s + i;
+        int c = c_s;
+        // Column Major Read
+        smem_vec[i] = smem[c * kTileDim + r];
+      }
+#pragma unroll
+      for (int smem_idx = 0; smem_idx < kNVecSMem; ++smem_idx) {
+        // Step 3.2: Compute local amax
+        CType amax = 0;
+#pragma unroll
+        for (int i = 0; i < kNVecOut; ++i) {
+          amax = fmaxf(amax, fabsf(smem_vec[i].data.elt[smem_idx]));
+        }
+        // Step 3.3: Reduce amax
+#pragma unroll
+        for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) {
+#ifdef __HIP_PLATFORM_AMD__
+          const float other_amax =
+              __shfl_down_sync((unsigned long long)(mask), amax, delta, kThreadsPerWarp);
+#else
+          const float other_amax = __shfl_down_sync(mask, amax, delta);
+#endif
+          __builtin_assume(amax >= 0);
+          __builtin_assume(other_amax >= 0);
+          amax = fmaxf(amax, other_amax);
+        }
+#ifdef __HIP_PLATFORM_AMD__
+        amax = __shfl_sync((unsigned long long)(mask), amax, src_lane, kThreadsPerWarp);
+#else
+        amax = __shfl_sync(mask, amax, src_lane);
+#endif
+        // Step 3.4: Compute scale
+        CType scale;
+        scale = compute_scale_from_types<IType, OType>(amax, epsilon, pow_2_scaling);
+        // Step 3.5: Write scale_inv_t
+        bool write_scale_inv = is_src_lane;
+        if constexpr (!kAligned) {
+          write_scale_inv &= (r_g + smem_idx < row_length);
+        }
+        if (write_scale_inv) {
+          CType scale_inv = 1.0 / scale;
+          size_t row_idx = static_cast<size_t>(blockIdx.x) * kTileDim + c_s * kNVecSMem + smem_idx;
+          size_t col_idx = static_cast<size_t>(blockIdx.y);
+          tile_scales_inv_t[row_idx * scale_t_stride_y + col_idx * scale_t_stride_x] = scale_inv;
+        }
+        // Step 3.6: Quantize
+        OVec output_vec;
+#pragma unroll
+        for (int i = 0; i < kNVecOut; ++i) {
+          if constexpr (std::is_same_v<OType, int8_t>) {
+            output_vec.data.elt[i] = static_cast<OType>(lroundf(fmaxf(
+                -127.0f, fminf(127.0f, static_cast<CType>(smem_vec[i].data.elt[smem_idx]) * scale))));
+          } else {
+            output_vec.data.elt[i] =
+                static_cast<OType>(static_cast<CType>(smem_vec[i].data.elt[smem_idx]) * scale);
+          }
+        }
+        // Step 3.7: Store output_t
+        if constexpr (kAligned) {
+          output_vec.store_to(output_g + smem_idx * num_rows);
+        } else {
+          if (r_g + smem_idx < row_length) {
+            output_vec.store_to_elts(output_g + smem_idx * num_rows, 0, num_ele);
+          }
+        }
+      }
+      // Step 3.8: Update output address, column index of shared memory (and row index of global memory for not aligned case)
+      output_g += stride_g;
+      c_s += c_stride;
+      if constexpr (!kAligned) {
+        r_g += c_stride * kNVecSMem;
+      }
+    }
+  }
+}
+#endif
 }  // namespace
 }  // namespace transformer_engine
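
The new kernel drops the BitFloat packing and instead stages the full fp32 tile in shared memory in column-major order (smem[c * kTileDim + r]), so both the row-wise pass and the transposed pass read it with simple strides. A small host-side C++ sketch of transposing a tile through a column-major staging buffer (a tiny hypothetical tile size, not the kernel's real dimensions or vector widths):

#include <array>
#include <cstdio>

constexpr int kTileDim = 4;  // hypothetical tiny tile, for illustration only

int main() {
  std::array<float, kTileDim * kTileDim> in{}, staging{}, out{};
  for (int i = 0; i < kTileDim * kTileDim; ++i) in[i] = static_cast<float>(i);

  // Stage the row-major input into a column-major buffer: element (r, c)
  // lands at staging[c * kTileDim + r], mirroring the kernel's smem layout.
  for (int r = 0; r < kTileDim; ++r)
    for (int c = 0; c < kTileDim; ++c)
      staging[c * kTileDim + r] = in[r * kTileDim + c];

  // Reading the staging buffer contiguously yields the transposed tile.
  for (int i = 0; i < kTileDim * kTileDim; ++i) out[i] = staging[i];

  for (int r = 0; r < kTileDim; ++r) {
    for (int c = 0; c < kTileDim; ++c) std::printf("%5.0f", out[r * kTileDim + c]);
    std::printf("\n");
  }
  return 0;
}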
@@ -767,23 +916,49 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
       TRANSFORMER_ENGINE_SWITCH_CONDITION(
           full_tile, kAligned,
 #ifdef __HIP_PLATFORM_AMD__
-          using HipSMemType =
-              std::conditional_t<std::is_same_v<InputType, float>, BitFloat, InputType>;
-          size_t smem_bytes = kSMemSize * sizeof(HipSMemType);
+          if constexpr (std::is_same_v<InputType, float>) {
+            size_t smem_bytes = kFP32SMemSize * sizeof(InputType);
+            if (smem_bytes >= 48 * 1024) {
+              cudaError_t err = cudaFuncSetAttribute(
+                  (const void*)&block_scaled_1d_cast_transpose_kernel_fp32<kAligned, float,
+                                                                           InputType, OutputType>,
+                  cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
+            }
+            block_scaled_1d_cast_transpose_kernel_fp32<kAligned, float, InputType, OutputType>
+                <<<grid, kThreadsPerBlock, smem_bytes, stream>>>(
+                    reinterpret_cast<const InputType*>(input.dptr),
+                    reinterpret_cast<OutputType*>(output.dptr),
+                    reinterpret_cast<OutputType*>(output_t.dptr),
+                    reinterpret_cast<float*>(scale_inv.dptr),
+                    reinterpret_cast<float*>(scale_inv_t.dptr),
+                    row_length, num_rows, scale_stride_x, scale_stride_y,
+                    scale_t_stride_x, scale_t_stride_y, epsilon,
+                    rowwise_option, columnwise_option, pow2_scale);
+          } else {
+            size_t smem_bytes = kSMemSize * sizeof(InputType);
+            // shared memory must be requested up
+            if (smem_bytes >= 48 * 1024) {
+              cudaError_t err = cudaFuncSetAttribute(
+                  (const void*)&block_scaled_1d_cast_transpose_kernel<kAligned, float,
+                                                                      InputType, OutputType>,
+                  cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
+            }
+            block_scaled_1d_cast_transpose_kernel<kAligned, float, InputType, OutputType>
+                <<<grid, kThreadsPerBlock, smem_bytes, stream>>>(
+                    reinterpret_cast<const InputType*>(input.dptr),
+                    reinterpret_cast<OutputType*>(output.dptr),
+                    reinterpret_cast<OutputType*>(output_t.dptr),
+                    reinterpret_cast<float*>(scale_inv.dptr),
+                    reinterpret_cast<float*>(scale_inv_t.dptr),
+                    row_length, num_rows, scale_stride_x, scale_stride_y,
+                    scale_t_stride_x, scale_t_stride_y, epsilon,
+                    rowwise_option, columnwise_option, pow2_scale);
+          }
 #else
           size_t smem_bytes = kSMemSize * sizeof(InputType);
-#endif
           // shared memory must be requested up
           if (smem_bytes >= 48 * 1024) {
-#ifdef __HIP_PLATFORM_AMD__
-            cudaError_t err = cudaFuncSetAttribute(
-                (const void*)&block_scaled_1d_cast_transpose_kernel<kAligned, float, InputType,
-                                                                    OutputType>,
-                cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
-#else
             cudaError_t err = cudaFuncSetAttribute(
                 &block_scaled_1d_cast_transpose_kernel<kAligned, float, InputType, OutputType>,
                 cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
-#endif
             NVTE_CHECK(err == cudaSuccess, "Failed to set dynamic shared memory size.");
           }
           block_scaled_1d_cast_transpose_kernel<kAligned, float, InputType, OutputType>
              <<<grid, kThreadsPerBlock, smem_bytes, stream>>>(
@@ -793,7 +968,9 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
                   reinterpret_cast<float*>(scale_inv.dptr),
                   reinterpret_cast<float*>(scale_inv_t.dptr),
                   row_length, num_rows, scale_stride_x, scale_stride_y,
                   scale_t_stride_x, scale_t_stride_y, epsilon, rowwise_option,
-                  columnwise_option, pow2_scale);)  // kAligned
+                  columnwise_option, pow2_scale);
+#endif
+          )  // kAligned
       )  // OutputType
   )  // InputType

   NVTE_CHECK_CUDA(cudaGetLastError());