Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
fengzch-das
nunchaku
Commits
8be63f64
"server/text_generation_server/models/seq2seq_lm.py" did not exist on "2ad895a6cc530474cae7e24ace1e463018172d0e"
Commit
8be63f64
authored
Dec 22, 2025
by
fengzch
Browse files
PTX 指令替换
parent
d21ab0f5
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
231 additions
and
141 deletions
+231
-141
src/kernels/awq/dequantize.cuh
src/kernels/awq/dequantize.cuh
+72
-21
src/kernels/awq/gemm_awq.cu
src/kernels/awq/gemm_awq.cu
+18
-20
src/kernels/zgemm/gemm_utils.cuh
src/kernels/zgemm/gemm_utils.cuh
+141
-100
No files found.
src/kernels/awq/dequantize.cuh
View file @
8be63f64
...
...
@@ -81,6 +81,53 @@ __forceinline__ __device__ void dequantize_s4_to_fp16x2(half2 const &source, uin
// printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
}
// Device-side bfloat16 -> float conversion.
// Reinterprets the raw bf16 bit pattern as the upper 16 bits of an
// IEEE-754 binary32 word (bf16 is exactly the high half of a float).
__device__ float bf16_to_float_device(uint16_t bf16) {
    uint32_t bits = static_cast<uint32_t>(bf16) << 16;
    return __uint_as_float(bits);
}
// Device-side float -> bfloat16 conversion.
// Truncates the 23-bit mantissa to bf16's 7 bits using
// round-to-nearest-even on the 16 discarded bits.
__device__ uint16_t float_to_bf16_device(float f) {
    uint32_t word = __float_as_uint(f);
    // Bias of 0x7FFF plus the LSB of the surviving mantissa makes exact
    // half-way cases round toward the even result.
    word += 0x7FFFu + ((word >> 16) & 1u);
    return static_cast<uint16_t>(word >> 16);
}
// C++ implementation of a bfloat16x2 FMA: treats a, b, c as pairs of bf16
// values packed into uint32_t words and computes (a * b) + c lane-wise,
// carrying each lane through float and rounding back to bf16.
__device__ uint32_t fma_bf16x2_cpp(uint32_t a, uint32_t b, uint32_t c) {
    // One bf16 lane: widen to float, multiply-add, narrow back to bf16.
    auto fma_lane = [](uint16_t av, uint16_t bv, uint16_t cv) -> uint16_t {
        float prod = bf16_to_float_device(av) * bf16_to_float_device(bv)
                     + bf16_to_float_device(cv);
        return float_to_bf16_device(prod);
    };
    uint16_t hi = fma_lane(static_cast<uint16_t>(a >> 16),
                           static_cast<uint16_t>(b >> 16),
                           static_cast<uint16_t>(c >> 16));
    uint16_t lo = fma_lane(static_cast<uint16_t>(a & 0xFFFF),
                           static_cast<uint16_t>(b & 0xFFFF),
                           static_cast<uint16_t>(c & 0xFFFF));
    // Repack: high lane in bits 16-31, low lane in bits 0-15.
    return (static_cast<uint32_t>(hi) << 16) | lo;
}
__forceinline__
__device__
void
dequantize_s4_to_fp16x2
(
__nv_bfloat162
const
&
source
,
uint4
*
result
)
{
// dequantize_s4_to_fp16x2(reinterpret_cast<const half2 &>(source), result);
...
...
@@ -103,22 +150,27 @@ __forceinline__ __device__ void dequantize_s4_to_fp16x2(__nv_bfloat162 const &so
static
constexpr
uint32_t
MASK
=
0x000f000f
;
static
constexpr
uint32_t
I4s_TO_BF16s_MAGIC_NUM
=
0x43004300
;
// Extract elt_01 - (i4s & 0x000f000f) | 0x43004300
asm
volatile
(
"lop3.b32 %0, %1, %2, %3, %4;
\n
"
:
"=r"
(
h
[
0
])
:
"r"
(
i4s
),
"n"
(
MASK
),
"n"
(
I4s_TO_BF16s_MAGIC_NUM
),
"n"
(
immLut
));
// Extract elt_23 ((i4s >> 4) & 0x000f000f) | 0x43004300
asm
volatile
(
"lop3.b32 %0, %1, %2, %3, %4;
\n
"
:
"=r"
(
h
[
1
])
:
"r"
(
i4s
>>
4
),
"n"
(
MASK
),
"n"
(
I4s_TO_BF16s_MAGIC_NUM
),
"n"
(
immLut
));
// Extract elt_45 ((i4s >> 8) & 0x000f000f) | 0x43004300
asm
volatile
(
"lop3.b32 %0, %1, %2, %3, %4;
\n
"
:
"=r"
(
h
[
2
])
:
"r"
(
i4s
>>
8
),
"n"
(
MASK
),
"n"
(
I4s_TO_BF16s_MAGIC_NUM
),
"n"
(
immLut
));
// Extract elt_67 ((i4s >> 12) & 0x000f000f) | 0x43004300
asm
volatile
(
"lop3.b32 %0, %1, %2, %3, %4;
\n
"
:
"=r"
(
h
[
3
])
:
"r"
(
i4s
>>
12
),
"n"
(
MASK
),
"n"
(
I4s_TO_BF16s_MAGIC_NUM
),
"n"
(
immLut
));
// // Extract elt_01 - (i4s & 0x000f000f) | 0x43004300
// asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
// : "=r"(h[0])
// : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
// // Extract elt_23 ((i4s >> 4) & 0x000f000f) | 0x43004300
// asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
// : "=r"(h[1])
// : "r"(i4s >> 4), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
// // Extract elt_45 ((i4s >> 8) & 0x000f000f) | 0x43004300
// asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
// : "=r"(h[2])
// : "r"(i4s >> 8), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
// // Extract elt_67 ((i4s >> 12) & 0x000f000f) | 0x43004300
// asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
// : "=r"(h[3])
// : "r"(i4s >> 12), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
h
[
0
]
=
((
i4s
&
MASK
)
|
I4s_TO_BF16s_MAGIC_NUM
);
h
[
1
]
=
(((
i4s
>>
4
)
&
MASK
)
|
I4s_TO_BF16s_MAGIC_NUM
);
h
[
2
]
=
(((
i4s
>>
8
)
&
MASK
)
|
I4s_TO_BF16s_MAGIC_NUM
);
h
[
3
]
=
(((
i4s
>>
12
)
&
MASK
)
|
I4s_TO_BF16s_MAGIC_NUM
);
// static constexpr uint32_t BF16_BIAS = 0xC308C308;
// This is the BF16 {-128, -128} represented as an integer, we do not need to map to [-8, 7]
...
...
@@ -134,9 +186,8 @@ __forceinline__ __device__ void dequantize_s4_to_fp16x2(__nv_bfloat162 const &so
// asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[2]) : "r"(h[2]), "r"(BF16_ONE), "r"(BF16_BIAS));
// // Convert elt_67
// asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(BF16_ONE), "r"(BF16_BIAS));
h
[
0
]
=
__hsub
(
h
[
0
],
__float2bfloat16_rn
(
128.0
f
));
h
[
1
]
=
__hsub
(
h
[
1
],
__float2bfloat16_rn
(
128.0
f
));
h
[
2
]
=
__hsub
(
h
[
2
],
__float2bfloat16_rn
(
128.0
f
));
h
[
3
]
=
__hsub
(
h
[
3
],
__float2bfloat16_rn
(
128.0
f
));
// printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
h
[
0
]
=
fma_bf16x2_cpp
(
h
[
0
],
BF16_ONE
,
BF16_BIAS
);
h
[
1
]
=
fma_bf16x2_cpp
(
h
[
1
],
BF16_ONE
,
BF16_BIAS
);
h
[
2
]
=
fma_bf16x2_cpp
(
h
[
2
],
BF16_ONE
,
BF16_BIAS
);
h
[
3
]
=
fma_bf16x2_cpp
(
h
[
3
],
BF16_ONE
,
BF16_BIAS
);
}
src/kernels/awq/gemm_awq.cu
View file @
8be63f64
...
...
// Load four m8n8 b16 matrix fragments from shared memory into the
// per-thread fragment buffer using the warp-wide `ldmatrix` PTX
// instruction (NVIDIA-only; requires SM75+ — TODO confirm target archs).
// shared_warp: destination fragment storage; ax0_0 selects an 8-element
// slot; addr: shared-memory address in the form ldmatrix expects.
template<typename f16_t>
__inline__ __device__ void ldmatrix_m8n8_x4_b16(f16_t *shared_warp, int ax0_0, uint32_t addr) {
    static_assert(std::is_same<f16_t, half>::value || std::is_same<f16_t, __nv_bfloat16>::value,
                  "ldmatrix_m8n8_x4_b16 supports only half or __nv_bfloat16 types.");
    // View the destination slot as four 32-bit registers.
    unsigned *frag = reinterpret_cast<unsigned *>(shared_warp + ax0_0 * 8);
    asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16"
                 "{%0, %1, %2, %3}, [%4];"
                 : "=r"(frag[0]), "=r"(frag[1]), "=r"(frag[2]), "=r"(frag[3])
                 : "r"(addr));
}
// Transposed variant of ldmatrix_m8n8_x4_b16: loads four m8n8 b16 matrix
// fragments from shared memory with the `.trans` qualifier, transposing
// each 8x8 fragment during the load (NVIDIA-only PTX).
template<typename f16_t>
__inline__ __device__ void ldmatrix_m8n8_x4_trans_b16(f16_t *shared_warp, int ax0_0, uint32_t addr) {
    static_assert(std::is_same<f16_t, half>::value || std::is_same<f16_t, __nv_bfloat16>::value,
                  "ldmatrix_m8n8_x4_trans_b16 supports only half or __nv_bfloat16 types.");
    // View the destination slot as four 32-bit registers.
    unsigned *frag = reinterpret_cast<unsigned *>(shared_warp + ax0_0 * 8);
    asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
                 "{%0, %1, %2, %3}, [%4];"
                 : "=r"(frag[0]), "=r"(frag[1]), "=r"(frag[2]), "=r"(frag[3])
                 : "r"(addr));
}
__inline__
__device__
void
cp_async_cg_A
(
uint32_t
smem_int_ptr
,
const
uint4
*
__restrict__
src
,
bool
mask
)
{
...
...
@@ -383,10 +381,10 @@ __global__ void gemm_w4a16_T1(f16_t *__restrict__ A,
int
M
,
int
N
,
int
K
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
trap_unsupported_arch
();
return
;
#endif
//
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
//
trap_unsupported_arch();
//
return;
//
#endif
using
f162_t
=
typename
packed_as
<
f16_t
,
2
>::
type
;
constexpr
int
NUM_WARPS_MN
=
CTA_M
/
WARP_M
*
CTA_N
/
WARP_N
;
...
...
src/kernels/zgemm/gemm_utils.cuh
View file @
8be63f64
...
...
@@ -46,37 +46,73 @@ __device__ __forceinline__ static T load(const T *addr) {
return
*
addr
;
}
// template<typename T>
// __device__ __forceinline__ static T load_pred(const T *addr, bool pred) {
// if constexpr (sizeof(T) == 4) {
// uint32_t data;
// // asm volatile("{ .reg .pred loadpred; setp.ne.b32 loadpred, %2, 0;"
// // "@loadpred ld.global.nc.b32 %0, [%1];"
// // "}"
// // : "=r"(data)
// // : "l"(addr), "r"((int)pred));
// return *reinterpret_cast<T *>(&data);
// }
// if constexpr (sizeof(T) == 8) {
// uint2 data;
// // asm volatile("{ .reg .pred loadpred; setp.ne.b32 loadpred, %3, 0;"
// // "@loadpred ld.global.nc.v2.b32 {%0, %1}, [%2];"
// // "}"
// // : "=r"(data.x), "=r"(data.y)
// // : "l"(addr), "r"((int)pred));
// return *reinterpret_cast<T *>(&data);
// }
// if constexpr (sizeof(T) == 16) {
// uint4 data;
// // asm volatile("{ .reg .pred loadpred; setp.ne.b32 loadpred, %5, 0;"
// // "@loadpred ld.global.nc.v4.b32 {%0, %1, %2, %3}, [%4];"
// // "}"
// // : "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w)
// // : "l"(addr), "r"((int)pred));
// return *reinterpret_cast<T *>(&data);
// }
// T result;
// if (pred) {
// result = *addr;
// }
// return result;
// }
template
<
typename
T
>
__device__
__forceinline__
static
T
load_pred
(
const
T
*
addr
,
bool
pred
)
{
if
constexpr
(
sizeof
(
T
)
==
4
)
{
uint32_t
data
;
// asm volatile("{ .reg .pred loadpred; setp.ne.b32 loadpred, %2, 0;"
// "@loadpred ld.global.nc.b32 %0, [%1];"
//
"}"
//
: "=r"(data)
//
: "l"(addr), "r"((int)pred))
;
// printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
if
(
pred
)
{
const
unsigned
char
*
src
=
reinterpret_cast
<
const
unsigned
char
*>
(
addr
);
unsigned
char
*
dst
=
reinterpret_cast
<
unsigned
char
*>
(
&
data
);
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
++
i
)
dst
[
i
]
=
src
[
i
]
;
}
return
*
reinterpret_cast
<
T
*>
(
&
data
);
}
if
constexpr
(
sizeof
(
T
)
==
8
)
{
uint2
data
;
// asm volatile("{ .reg .pred loadpred; setp.ne.b32 loadpred, %3, 0;"
//
"@loadpred ld.global.nc.v2.b32 {%0, %1}, [%2];"
//
"}"
//
: "=r"(data.x), "=r"(data.y)
//
: "l"(addr), "r"((int)pred))
;
// printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
if
(
pred
)
{
const
unsigned
char
*
src
=
reinterpret_cast
<
const
unsigned
char
*>
(
addr
);
unsigned
char
*
dst
=
reinterpret_cast
<
unsigned
char
*>
(
&
data
);
#pragma unroll
for
(
int
i
=
0
;
i
<
8
;
++
i
)
dst
[
i
]
=
src
[
i
]
;
}
return
*
reinterpret_cast
<
T
*>
(
&
data
);
}
if
constexpr
(
sizeof
(
T
)
==
16
)
{
uint4
data
;
// asm volatile("{ .reg .pred loadpred; setp.ne.b32 loadpred, %5, 0;"
// "@loadpred ld.global.nc.v4.b32 {%0, %1, %2, %3}, [%4];"
// "}"
// : "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w)
// : "l"(addr), "r"((int)pred));
// printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
if
(
pred
)
{
const
unsigned
char
*
src
=
reinterpret_cast
<
const
unsigned
char
*>
(
addr
);
unsigned
char
*
dst
=
reinterpret_cast
<
unsigned
char
*>
(
&
data
);
#pragma unroll
for
(
int
i
=
0
;
i
<
16
;
++
i
)
dst
[
i
]
=
src
[
i
];
}
return
*
reinterpret_cast
<
T
*>
(
&
data
);
}
...
...
@@ -92,21 +128,17 @@ __device__ __forceinline__ static void store(T *addr, T val) {
if
constexpr
(
shmem
)
{
if
constexpr
(
sizeof
(
T
)
==
8
)
{
uint2
data
=
*
reinterpret_cast
<
uint2
*>
(
&
val
);
// asm volatile(
// "st.shared.v2.b32 [%0], {%1, %2};" ::"l"((addr)), "r"(data.x), "r"(data.y));
// printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
asm
volatile
(
"st.shared.v2.b32 [%0], {%1, %2};"
::
"l"
((
addr
)),
"r"
(
data
.
x
),
"r"
(
data
.
y
));
return
;
}
if
constexpr
(
sizeof
(
T
)
==
16
)
{
uint4
data
=
*
reinterpret_cast
<
uint4
*>
(
&
val
);
// asm volatile("st.shared.v4.b32 [%0], {%1, %2, %3, %4};" ::"l"((addr)),
// "r"(data.x),
// "r"(data.y),
// "r"(data.z),
// "r"(data.w));
// printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
asm
volatile
(
"st.shared.v4.b32 [%0], {%1, %2, %3, %4};"
::
"l"
((
addr
)),
"r"
(
data
.
x
),
"r"
(
data
.
y
),
"r"
(
data
.
z
),
"r"
(
data
.
w
));
return
;
}
*
addr
=
val
;
...
...
@@ -115,17 +147,17 @@ __device__ __forceinline__ static void store(T *addr, T val) {
if
constexpr
(
sizeof
(
T
)
==
4
)
{
// __stcg(reinterpret_cast<unsigned int *>(addr), *reinterpret_cast<unsigned int *>(&val));
*
reinterpret_cast
<
unsigned
int
*>
(
addr
)
=
*
reinterpret_cast
<
unsigned
int
*>
(
&
val
);
*
reinterpret_cast
<
unsigned
int
*>
(
addr
)
=
*
reinterpret_cast
<
unsigned
int
*>
(
&
val
);
return
;
}
if
constexpr
(
sizeof
(
T
)
==
8
)
{
// __stcg(reinterpret_cast<uint2 *>(addr), *reinterpret_cast<uint2 *>(&val));
*
reinterpret_cast
<
uint2
*>
(
addr
)
=
*
reinterpret_cast
<
uint2
*>
(
&
val
);
*
reinterpret_cast
<
uint2
*>
(
addr
)
=
*
reinterpret_cast
<
uint2
*>
(
&
val
);
return
;
}
if
constexpr
(
sizeof
(
T
)
==
16
)
{
// __stcg(reinterpret_cast<uint4 *>(addr), *reinterpret_cast<uint4 *>(&val));
*
reinterpret_cast
<
uint4
*>
(
addr
)
=
*
reinterpret_cast
<
uint4
*>
(
&
val
);
*
reinterpret_cast
<
uint4
*>
(
addr
)
=
*
reinterpret_cast
<
uint4
*>
(
&
val
);
return
;
}
*
addr
=
val
;
...
...
@@ -135,39 +167,33 @@ template<typename T>
__device__
__forceinline__
static
void
store_pred
(
T
*
addr
,
T
val
,
bool
pred
)
{
if
constexpr
(
sizeof
(
T
)
==
4
)
{
uint32_t
data
=
*
reinterpret_cast
<
uint32_t
*>
(
&
val
);
// asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
// "@storepred st.global.cg.b32 [%1], %2;"
// "}" ::"r"((int)pred),
// "l"(addr),
// "r"(data));
// printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
asm
volatile
(
"{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
"@storepred st.global.cg.b32 [%1], %2;"
"}"
::
"r"
((
int
)
pred
),
"l"
(
addr
),
"r"
(
data
));
return
;
}
if
constexpr
(
sizeof
(
T
)
==
8
)
{
uint2
data
=
*
reinterpret_cast
<
uint2
*>
(
&
val
);
// asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
// "@storepred st.global.cg.v2.b32 [%1], {%2, %3};"
// "}" ::"r"((int)pred),
// "l"(addr),
// "r"(data.x),
// "r"(data.y));
// printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
asm
volatile
(
"{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
"@storepred st.global.cg.v2.b32 [%1], {%2, %3};"
"}"
::
"r"
((
int
)
pred
),
"l"
(
addr
),
"r"
(
data
.
x
),
"r"
(
data
.
y
));
return
;
}
if
constexpr
(
sizeof
(
T
)
==
16
)
{
uint4
data
=
*
reinterpret_cast
<
uint4
*>
(
&
val
);
// asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
// "@storepred st.global.cg.v4.b32 [%1], {%2, %3, %4, %5};"
// "}" ::"r"((int)pred),
// "l"(addr),
// "r"(data.x),
// "r"(data.y),
// "r"(data.z),
// "r"(data.w));
// printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
asm
volatile
(
"{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
"@storepred st.global.cg.v4.b32 [%1], {%2, %3, %4, %5};"
"}"
::
"r"
((
int
)
pred
),
"l"
(
addr
),
"r"
(
data
.
x
),
"r"
(
data
.
y
),
"r"
(
data
.
z
),
"r"
(
data
.
w
));
return
;
}
...
...
// Quantize two floats to signed 4-bit integers packed into the low byte of
// the result: bits 0-3 hold round(value.x), bits 4-7 hold round(value.y),
// both saturated to [-8, 7] and stored as 4-bit two's complement.
// Pure C++ replacement for `cvt.rni.s32.f32` + `cvt.pack.sat.s4.s32.b32`;
// the previous version still executed the NVIDIA-only cvt instructions and
// then overwrote their results, which both wasted work and broke the HIP
// port this commit is making.
template<>
__device__ __forceinline__ uint32_t quantize_float2<4, false>(float2 value) {
    // Round to nearest-even, matching cvt.rni semantics.
    int v1 = __float2int_rn(value.x);
    int v2 = __float2int_rn(value.y);
    // Saturate to the signed 4-bit range, matching cvt's .sat qualifier.
    int s1 = max(-8, min(7, v1));
    int s2 = max(-8, min(7, v2));
    // Keep only the low 4 bits (two's complement) and pack y above x.
    uint32_t u1 = s1 & 0xF;
    uint32_t u2 = s2 & 0xF;
    uint32_t result = (u2 << 4) | u1;
    return result;
}
...
...
// Quantize two floats to unsigned 4-bit integers packed into the low byte
// of the result: bits 0-3 hold round(value.x), bits 4-7 hold
// round(value.y), both saturated to [0, 15].
// Pure C++ replacement for `cvt.rni.s32.f32` + `cvt.pack.sat.u4.s32.b32`;
// the previous version still executed the NVIDIA-only cvt instructions and
// then overwrote their results, which both wasted work and broke the HIP
// port this commit is making.
template<>
__device__ __forceinline__ uint32_t quantize_float2<4, true>(float2 value) {
    // Round to nearest-even, matching cvt.rni semantics.
    int v1 = __float2int_rn(value.x);
    int v2 = __float2int_rn(value.y);
    // Saturate to the unsigned 4-bit range, matching cvt's .sat qualifier.
    unsigned int u1 = static_cast<unsigned int>(max(0, min(15, v1)));
    unsigned int u2 = static_cast<unsigned int>(max(0, min(15, v2)));
    // Pack y in the upper nibble, x in the lower nibble.
    uint32_t result = (u2 << 4) | u1;
    return result;
}
...
...
// Quantize two floats to signed 8-bit integers packed into the low 16 bits
// of the result: bits 0-7 hold round(value.x), bits 8-15 hold
// round(value.y), both saturated to [-128, 127] as two's complement.
// Pure C++ replacement for `cvt.rni.s32.f32` + `cvt.pack.sat.s8.s32.b32`;
// the previous version still executed the NVIDIA-only cvt instructions and
// then overwrote their results, which both wasted work and broke the HIP
// port this commit is making.
template<>
__device__ __forceinline__ uint32_t quantize_float2<8, false>(float2 value) {
    // Round to nearest-even (equivalent to roundf with ties-to-even,
    // matching cvt.rni).
    int v1 = __float2int_rn(value.x);
    int v2 = __float2int_rn(value.y);
    // Saturate to the signed 8-bit range [-128, 127].
    int s1 = max(-128, min(127, v1));
    int s2 = max(-128, min(127, v2));
    // Keep only the low 8 bits — the 8-bit two's-complement bit pattern.
    unsigned int u1 = s1 & 0xFF;
    unsigned int u2 = s2 & 0xFF;
    uint32_t result = (u2 << 8) | u1;
    return result;
}
// Quantize two floats to a pair of FP4 (e2m1) values packed into the low
// byte of the result via the `cvt.rn.satfinite.e2m1x2.f32` PTX
// instruction: value.y feeds operand %1 and value.x feeds %2, with
// round-to-nearest and saturation to the finite e2m1 range.
// NOTE(review): this inline PTX is NVIDIA-only (and e2m1 conversion needs
// a very recent architecture — TODO confirm minimum SM); unlike the other
// asm blocks replaced in this commit, no portable C++ fallback exists here
// yet, so this function will not build for HIP.
__device__ __forceinline__ uint32_t quantize_float2_fp4(float2 value) {
    uint32_t result;
    asm volatile("{ .reg .b8 tmp; cvt.rn.satfinite.e2m1x2.f32 tmp, %1, %2; cvt.u32.u8 %0, tmp; }"
                 : "=r"(result)
                 : "f"(value.y), "f"(value.x));
    return result;
}
...
...
@@ -372,17 +416,15 @@ __device__ __forceinline__ static __nv_bfloat162 h2div(__nv_bfloat162 a, __nv_bf
};
// Atomically accumulate `val` into the float at `addr` (global memory).
// Portable replacement for the previous inline PTX
// `red.relaxed.gpu.global.add.f32`: atomicAdd on a float performs the same
// device-scope relaxed add-reduction (the returned old value is simply
// discarded), and it also compiles under HIP, unlike inline PTX.
__device__ __forceinline__ static void reduce_add(float *addr, float val) {
    atomicAdd(addr, val);
}
// Predicated version of reduce_add: atomically accumulate `val` into the
// float at `addr` only when `pred` is true.
// Portable replacement for the previous inline PTX (`setp` + predicated
// `red.relaxed.gpu.global.add.f32`): a plain branch around atomicAdd has
// the same effect and also compiles under HIP, unlike inline PTX.
__device__ __forceinline__ static void reduce_add_pred(float *addr, float val, bool pred) {
    if (pred) {
        atomicAdd(addr, val);
    }
}
template
<
int
cnt
,
typename
F
>
...
...
@@ -394,13 +436,15 @@ __device__ __forceinline__ static void unrolled_loop(F &&lambda) {
// int2float is slow on sm_80 and before
// val in [-4194304, 4194303]
// Fast int -> float conversion via the exponent-bias trick: place the low
// 23 bits of val in the mantissa of 1.5 * 2^23 (0x4B400000), reinterpret
// as float, and subtract the bias 12582912.0f (= 1.5 * 2^23) to recover
// val exactly for the documented range.
// BUG FIX: the previous version declared `float fval;` and returned
// `fval - 12582912.0f` *before* computing anything — returning an
// uninitialized value — while the correct memcpy-based computation sat
// unreachable after the return statement.
__device__ __forceinline__ static float int2float_fast(int val) {
    // fval = (val & 0x7FFFFF) ^ 0x4B400000, done through memcpy to avoid
    // strict-aliasing UB (portable equivalent of the old lop3.b32 PTX).
    unsigned int bits = (val & 0x7FFFFF) ^ 0x4B400000;
    float fval;
    memcpy(&fval, &bits, sizeof(float));
    return fval - 12582912.0f;
}
template
<
typename
To
,
typename
From
>
...
...
// Fast (int, int) -> half2 conversion via bit tricks; valid only for the
// input range this "_8192" variant was designed for (see the sibling
// variants' range comments — not restated here because it is not visible
// at this call site; TODO confirm against the original header comment).
// The prmt/lop3 PTX instructions are replaced by their exact C++
// equivalents so the code also builds under HIP.
__device__ __forceinline__ static half2 int2half2_fast_8192(int x, int y) {
    uint32_t ival;
    uint32_t hval;
    // ival.lo = x.lo; ival.hi = y.lo;
    // Equivalent of `prmt.b32 ival, x, y, 0x5410` (bytes 0,1 of x into the
    // low half, bytes 0,1 of y into the high half).
    ival = (static_cast<uint32_t>(x) & 0xFFFFu) | ((static_cast<uint32_t>(y) & 0xFFFFu) << 16);
    ival = ival >> 4;
    // hval = (ival & 0x03FF03FF) ^ 0x76007600
    // Equivalent of `lop3.b32` with immLut (0xF0 & 0xCC) ^ 0xAA, i.e. (a & b) ^ c.
    hval = (ival & 0x03FF03FFu) ^ 0x76007600u;
    return __hadd2(kernels::bit_cast<half2>(hval), half2(-24576.0f, -24576.0f));
}
// val in [-4096, 4095], steps of 8, round to nearest
...
...
// Fast (int, int) -> half2 conversion; val in [-4096, 4095], steps of 8,
// round to nearest (per the original range comment for this variant).
// The prmt/lop3 PTX instructions are replaced by their exact C++
// equivalents so the code also builds under HIP.
__device__ __forceinline__ static half2 int2half2_fast_4096_rn(int x, int y) {
    uint32_t ival;
    uint32_t hval;
    // ival.lo = x.hi; ival.hi = y.hi;
    // <=> divide x and y by 65536 and pack them.
    // Equivalent of `prmt.b32 ival, x, y, 0x7632` (bytes 2,3 of x into the
    // low half, bytes 2,3 of y into the high half).
    ival = ((static_cast<uint32_t>(x) >> 16) & 0xFFFFu) | (static_cast<uint32_t>(y) & 0xFFFF0000u);
    // hval = (ival & 0x03FF03FF) ^ 0x72007200
    // Equivalent of `lop3.b32` with immLut (0xF0 & 0xCC) ^ 0xAA, i.e. (a & b) ^ c.
    hval = (ival & 0x03FF03FFu) ^ 0x72007200u;
    return __hadd2(kernels::bit_cast<half2>(hval), half2(-12288.0f, -12288.0f));
}
// val in [-512, 511]
...
...
// Fast (int, int) -> half2 conversion; val in [-512, 511] (per the
// original range comment for this variant).
// The prmt/lop3 PTX instructions are replaced by their exact C++
// equivalents so the code also builds under HIP.
__device__ __forceinline__ static half2 int2half2_fast_512(int x, int y) {
    uint32_t ival;
    uint32_t hval;
    // ival.lo = x.lo; ival.hi = y.lo;
    // (the original also said "divide by 65536" here, which matches the
    //  4096 variant, not this byte selection — looks like a copy-paste;
    //  the selector 0x5410 takes the *low* halves. TODO confirm intent.)
    // Equivalent of `prmt.b32 ival, x, y, 0x5410`.
    ival = (static_cast<uint32_t>(x) & 0xFFFFu) | ((static_cast<uint32_t>(y) & 0xFFFFu) << 16);
    // hval = (ival & 0x03FF03FF) ^ 0x66006600
    // Equivalent of `lop3.b32` with immLut (0xF0 & 0xCC) ^ 0xAA, i.e. (a & b) ^ c.
    hval = (ival & 0x03FF03FFu) ^ 0x66006600u;
    return __hadd2(kernels::bit_cast<half2>(hval), half2(-1536.0f, -1536.0f));
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment