fengzch-das / nunchaku · Commits · 57e50f8d

Commit 57e50f8d (Unverified)
Authored May 01, 2025 by Muyang Li; committed by GitHub on May 01, 2025
Parent: b737368d

style: upgrade the linter (#339)

* style: reformated codes
Showing 20 changed files with 1613 additions and 1643 deletions.
Changed files:

src/kernels/zgemm/gemm_utils.cuh                          (+149 / -160)
src/kernels/zgemm/gemm_w4a4.cu                            (+71  / -74)
src/kernels/zgemm/gemm_w4a4.cuh                           (+372 / -345)
src/kernels/zgemm/gemm_w4a4_launch.cuh                    (+46  / -42)
src/kernels/zgemm/gemm_w4a4_launch_bf16_fp4.cu            (+2   / -2)
src/kernels/zgemm/gemm_w4a4_launch_bf16_int4.cu           (+2   / -2)
src/kernels/zgemm/gemm_w4a4_launch_fp16_fp4.cu            (+2   / -2)
src/kernels/zgemm/gemm_w4a4_launch_fp16_int4.cu           (+2   / -2)
src/kernels/zgemm/gemm_w4a4_launch_fp16_int4_fasteri2f.cu (+2   / -2)
src/kernels/zgemm/gemm_w4a4_launch_impl.cuh               (+201 / -187)
src/kernels/zgemm/gemm_w4a4_test.cu                       (+31  / -33)
src/kernels/zgemm/gemm_w8a8.cu                            (+39  / -42)
src/kernels/zgemm/gemm_w8a8.cuh                           (+148 / -162)
src/kernels/zgemm/lora.cuh                                (+70  / -75)
src/kernels/zgemm/mma.cuh                                 (+137 / -151)
src/kernels/zgemm/mma_earlycuda.cuh                       (+156 / -202)
src/kernels/zgemm/zgemm.h                                 (+50  / -47)
src/layernorm.cpp                                         (+21  / -12)
src/layernorm.h                                           (+18  / -10)
src/pytorch_compat.h                                      (+94  / -91)
src/kernels/zgemm/gemm_utils.cuh

@@ -7,7 +7,7 @@
namespace nunchaku::kernels {

static constexpr int clamp(int val, int min, int max) {
    if (val < min)
        return min;
    if (val > max)
        return max;

@@ -15,17 +15,20 @@ static constexpr int clamp(int val, int min, int max) {
}

template<bool shmem = false, typename T>
__device__ __forceinline__ static T load(const T *addr) {
    if constexpr (shmem) {
        if constexpr (sizeof(T) == 8) {
            uint2 data;
            asm volatile("ld.shared.v2.b32 {%0, %1}, [%2];"
                         : "=r"(data.x), "=r"(data.y)
                         : "l"(__cvta_generic_to_shared(addr)));
            return *reinterpret_cast<T *>(&data);
        }
        if constexpr (sizeof(T) == 16) {
            uint4 data;
            asm volatile("ld.shared.v4.b32 {%0, %1, %2, %3}, [%4];"
                         : "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w)
                         : "l"(__cvta_generic_to_shared(addr)));
            return *reinterpret_cast<T *>(&data);
        }
        return *addr;

@@ -44,30 +47,32 @@ static T load(const T *addr) {
}

template<typename T>
__device__ __forceinline__ static T load_pred(const T *addr, bool pred) {
    if constexpr (sizeof(T) == 4) {
        uint32_t data;
        asm volatile("{ .reg .pred loadpred; setp.ne.b32 loadpred, %2, 0;"
                     "@loadpred ld.global.nc.b32 %0, [%1];"
                     "}"
                     : "=r"(data)
                     : "l"(addr), "r"((int)pred));
        return *reinterpret_cast<T *>(&data);
    }
    if constexpr (sizeof(T) == 8) {
        uint2 data;
        asm volatile("{ .reg .pred loadpred; setp.ne.b32 loadpred, %3, 0;"
                     "@loadpred ld.global.nc.v2.b32 {%0, %1}, [%2];"
                     "}"
                     : "=r"(data.x), "=r"(data.y)
                     : "l"(addr), "r"((int)pred));
        return *reinterpret_cast<T *>(&data);
    }
    if constexpr (sizeof(T) == 16) {
        uint4 data;
        asm volatile("{ .reg .pred loadpred; setp.ne.b32 loadpred, %5, 0;"
                     "@loadpred ld.global.nc.v4.b32 {%0, %1, %2, %3}, [%4];"
                     "}"
                     : "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w)
                     : "l"(addr), "r"((int)pred));
        return *reinterpret_cast<T *>(&data);
    }

@@ -79,17 +84,21 @@ static T load_pred(const T *addr, bool pred) {
}

template<bool shmem = false, typename T>
__device__ __forceinline__ static void store(T *addr, T val) {
    if constexpr (shmem) {
        if constexpr (sizeof(T) == 8) {
            uint2 data = *reinterpret_cast<uint2 *>(&val);
            asm volatile("st.shared.v2.b32 [%0], {%1, %2};" ::"l"(__cvta_generic_to_shared(addr)),
                         "r"(data.x), "r"(data.y));
            return;
        }
        if constexpr (sizeof(T) == 16) {
            uint4 data = *reinterpret_cast<uint4 *>(&val);
            asm volatile("st.shared.v4.b32 [%0], {%1, %2, %3, %4};" ::"l"(__cvta_generic_to_shared(addr)),
                         "r"(data.x), "r"(data.y), "r"(data.z), "r"(data.w));
            return;
        }
        *addr = val;

@@ -107,35 +116,41 @@ static void store(T *addr, T val) {
        if constexpr (sizeof(T) == 16) {
            __stcg(reinterpret_cast<uint4 *>(addr), *reinterpret_cast<uint4 *>(&val));
            return;
        }
    }
    *addr = val;
}

template<typename T>
__device__ __forceinline__ static void store_pred(T *addr, T val, bool pred) {
    if constexpr (sizeof(T) == 4) {
        uint32_t data = *reinterpret_cast<uint32_t *>(&val);
        asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
                     "@storepred st.global.cg.b32 [%1], %2;"
                     "}" ::"r"((int)pred),
                     "l"(addr), "r"(data));
        return;
    }
    if constexpr (sizeof(T) == 8) {
        uint2 data = *reinterpret_cast<uint2 *>(&val);
        asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
                     "@storepred st.global.cg.v2.b32 [%1], {%2, %3};"
                     "}" ::"r"((int)pred),
                     "l"(addr), "r"(data.x), "r"(data.y));
        return;
    }
    if constexpr (sizeof(T) == 16) {
        uint4 data = *reinterpret_cast<uint4 *>(&val);
        asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
                     "@storepred st.global.cg.v4.b32 [%1], {%2, %3, %4, %5};"
                     "}" ::"r"((int)pred),
                     "l"(addr), "r"(data.x), "r"(data.y), "r"(data.z), "r"(data.w));
        return;
    }
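Note: the load_pred / store_pred helpers above keep the whole warp on the same instruction and let a PTX predicate mask off out-of-range lanes instead of branching around the access. The sketch below shows how such helpers are typically used to guard the ragged tail of a copy; it is a self-contained illustration (the copy_u32 kernel and the 4-byte helpers are re-derived here for the example, not taken from this commit).

#include <cstdint>

__device__ __forceinline__ uint32_t load_pred_u32(const uint32_t *addr, bool pred) {
    uint32_t data = 0; // preserved when the predicate is false ("+r" read-write constraint)
    asm volatile("{ .reg .pred p; setp.ne.b32 p, %2, 0;"
                 " @p ld.global.nc.b32 %0, [%1]; }"
                 : "+r"(data)
                 : "l"(addr), "r"((int)pred));
    return data;
}

__device__ __forceinline__ void store_pred_u32(uint32_t *addr, uint32_t val, bool pred) {
    asm volatile("{ .reg .pred p; setp.ne.b32 p, %0, 0;"
                 " @p st.global.cg.b32 [%1], %2; }" ::"r"((int)pred),
                 "l"(addr), "r"(val));
}

// Copies n 32-bit words; lanes past the end still issue the instructions but are predicated off.
__global__ void copy_u32(const uint32_t *src, uint32_t *dst, int n) {
    int i     = blockIdx.x * blockDim.x + threadIdx.x;
    bool pred = i < n;
    store_pred_u32(dst + i, load_pred_u32(src + i, pred), pred);
}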
@@ -144,198 +159,174 @@ static void store_pred(T *addr, T val, bool pred) {
    }
}

__device__ __forceinline__ static float2 half22float2(half2 val) {
    return __half22float2(val);
}

__device__ __forceinline__ static float2 half22float2(__nv_bfloat162 val) {
    return __bfloat1622float2(val);
}

template<typename T>
__device__ __forceinline__ static T float22half2(float2 val) = delete;

template<>
__device__ __forceinline__ half2 float22half2<half2>(float2 val) {
    return __float22half2_rn(val);
}

template<>
__device__ __forceinline__ __nv_bfloat162 float22half2<__nv_bfloat162>(float2 val) {
    return __float22bfloat162_rn(val);
}

template<typename T>
__device__ __forceinline__ static void unused_var(T &val, bool alwaysfalse) {
    volatile T *ptr = nullptr;
    if (alwaysfalse) {
        *ptr = val;
    }
}

__device__ __forceinline__ static void ldmatrix(const void *ptr, uint4 &out) {
    asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
                 : "=r"(out.x), "=r"(out.y), "=r"(out.z), "=r"(out.w)
                 : "l"(__cvta_generic_to_shared(ptr)));
}

template<typename T>
__device__ __forceinline__ static T movmatrix(T x) {
    asm volatile("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
                 : "=r"(*reinterpret_cast<uint32_t *>(&x))
                 : "r"(*reinterpret_cast<uint32_t *>(&x)));
    return x;
}

// x in low bit, y in high bit
template<int bitwidth, bool use_unsigned>
__device__ __forceinline__ uint32_t quantize_float2(float2 value) = delete;

template<>
__device__ __forceinline__ uint32_t quantize_float2<4, false>(float2 value) {
    int v1, v2;
    uint32_t result;
    asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v1) : "f"(value.x));
    asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v2) : "f"(value.y));
    asm volatile("cvt.pack.sat.s4.s32.b32 %0, %1, %2, 0;" : "=r"(result) : "r"(v2), "r"(v1));
    return result;
}

template<>
__device__ __forceinline__ uint32_t quantize_float2<4, true>(float2 value) {
    int v1, v2;
    uint32_t result;
    asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v1) : "f"(value.x));
    asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v2) : "f"(value.y));
    asm volatile("cvt.pack.sat.u4.s32.b32 %0, %1, %2, 0;" : "=r"(result) : "r"(v2), "r"(v1));
    return result;
}

template<>
__device__ __forceinline__ uint32_t quantize_float2<8, false>(float2 value) {
    int v1, v2;
    uint32_t result;
    asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v1) : "f"(value.x));
    asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v2) : "f"(value.y));
    asm volatile("cvt.pack.sat.s8.s32.b32 %0, %1, %2, 0;" : "=r"(result) : "r"(v2), "r"(v1));
    return result;
}

__device__ __forceinline__ uint32_t quantize_float2_fp4(float2 value) {
    uint32_t result;
    asm volatile("{ .reg .b8 tmp; cvt.rn.satfinite.e2m1x2.f32 tmp, %1, %2; cvt.u32.u8 %0, tmp; }"
                 : "=r"(result)
                 : "f"(value.y), "f"(value.x));
    return result;
}

__device__ __forceinline__ uint32_t quantize_float4_fp8(float4 value) {
    uint16_t lo, hi;
    asm volatile("cvt.rn.satfinite.e4m3x2.f32 %0, %1, %2;" : "=h"(lo) : "f"(value.y), "f"(value.x));
    asm volatile("cvt.rn.satfinite.e4m3x2.f32 %0, %1, %2;" : "=h"(hi) : "f"(value.w), "f"(value.z));
    return uint32_t(lo) | (uint32_t(hi) << 16);
}

__device__ __forceinline__ static float cuda_tanhf(float x) {
    float result;
    asm("tanh.approx.f32 %0, %1;" : "=f"(result) : "f"(x));
    return result;
}

__device__ __forceinline__ static float cuda_frcp(float x) {
    float result;
    asm("rcp.approx.ftz.f32 %0, %1;" : "=f"(result) : "f"(x));
    return result;
}

__device__ __forceinline__ static float cuda_frsqrt(float x) {
    float result;
    asm("rsqrt.approx.ftz.f32 %0, %1;" : "=f"(result) : "f"(x));
    return result;
}

__device__ __forceinline__ static float cuda_sin(float x) {
    float result;
    asm("sin.approx.ftz.f32 %0, %1;" : "=f"(result) : "f"(x));
    return result;
}

__device__ __forceinline__ static float cuda_cos(float x) {
    float result;
    asm("cos.approx.ftz.f32 %0, %1;" : "=f"(result) : "f"(x));
    return result;
}

__device__ __forceinline__ static float cuda_exp2(float x) {
    float result;
    asm("ex2.approx.ftz.f32 %0, %1;" : "=f"(result) : "f"(x));
    return result;
}

// https://forums.developer.nvidia.com/t/hardware-accelerated-computation-of-the-sigmoid-logistic-function/266206
__forceinline__ __device__ static float cuda_sigmoidf(float a) {
#if USE_TANH
    return fmaf(0.5, __tanhf(0.5f * a), 0.5f);
#else  // USE_TANH
    const float L2E = 1.442695041f; // log2(exp(1))
    float t, d, e, r;
    t = -L2E * a;
    asm("ex2.approx.ftz.f32 %0,%1;\n\t" : "=f"(e) : "f"(t));
    d = e + 1.0f;
    asm("rcp.approx.ftz.f32 %0,%1;\n\t" : "=f"(r) : "f"(d));
    return r;
#endif // USE_TANH
}

template<typename T>
__device__ __forceinline__ static T gelu_half2(T x) {
    float2 xf  = half22float2(x);
    float2 x3f = xf * xf * xf;
    float t1   = 0.5f + 0.5f * cuda_tanhf(0.79788456f * (xf.x + (0.044715f * x3f.x)));
    float t2   = 0.5f + 0.5f * cuda_tanhf(0.79788456f * (xf.y + (0.044715f * x3f.y)));
    return float22half2<T>(xf * make_float2(t1, t2));
}

template<typename T>
__device__ __forceinline__ static T gelu_half(T x) {
    float xf  = float(x);
    float x3f = xf * xf * xf;
    float t   = 0.5f + 0.5f * cuda_tanhf(0.79788456f * (xf + (0.044715f * x3f)));
    return (T)(xf * t);
}

template<typename T>
__device__ __forceinline__ static T silu(const T &x) {
    // x * sigmoid(x)
    return (T)((float)x * cuda_sigmoidf((float)x));
    // return (T)__fdividef((float)x, 1.0f + __expf((float)-x));
}

__device__ __forceinline__ static half2 h2div(half2 a, half2 b) {
    float2 af = half22float2(a);
    float2 bf = half22float2(b);
    float2 of;
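Note: gelu_half / gelu_half2 above use the tanh approximation of GELU, 0.5·x·(1 + tanh(sqrt(2/pi)·(x + 0.044715·x^3))). The short host-side check below compares that formula against the erf-based definition; it is an illustration added here, not part of the commit.

// Host-side sanity check of the tanh-based GELU approximation (illustrative only).
#include <cmath>
#include <cstdio>

static float gelu_tanh(float x) {
    float x3 = x * x * x;
    float t  = 0.5f + 0.5f * std::tanh(0.79788456f * (x + 0.044715f * x3)); // 0.79788456 ~ sqrt(2/pi)
    return x * t;
}

static float gelu_erf(float x) { // reference definition: 0.5 * x * (1 + erf(x / sqrt(2)))
    return 0.5f * x * (1.0f + std::erf(x * 0.70710678f));
}

int main() {
    for (float x = -4.0f; x <= 4.0f; x += 1.0f)
        std::printf("x=%5.1f  tanh-approx=%9.6f  erf=%9.6f\n", x, gelu_tanh(x), gelu_erf(x));
    return 0;
}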
@@ -343,8 +334,7 @@ static half2 h2div(half2 a, half2 b) {
    of.y = __fdividef(af.y, bf.y);
    return float22half2<half2>(of);
};

__device__ __forceinline__ static __nv_bfloat162 h2div(__nv_bfloat162 a, __nv_bfloat162 b) {
    float2 af = half22float2(a);
    float2 bf = half22float2(b);
    float2 of;

@@ -353,41 +343,37 @@ static __nv_bfloat162 h2div(__nv_bfloat162 a, __nv_bfloat162 b) {
    return float22half2<__nv_bfloat162>(of);
};

__device__ __forceinline__ static void reduce_add(float *addr, float val) {
    asm volatile("red.relaxed.gpu.global.add.f32 [%0], %1;" ::"l"(addr), "f"(val));
}

__device__ __forceinline__ static void reduce_add_pred(float *addr, float val, bool pred) {
    asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
                 "@storepred red.relaxed.gpu.global.add.f32 [%1], %2;"
                 "}" ::"r"((int)pred),
                 "l"(addr), "f"(val));
}

template<int cnt, typename F>
__device__ __forceinline__ static void unrolled_loop(F &&lambda) {
    auto call = [&]<int... Is>(std::integer_sequence<int, Is...>) { (lambda.template operator()<Is>(), ...); };
    call(std::make_integer_sequence<int, cnt>());
}

// int2float is slow on sm_80 and before
// val in [-4194304, 4194303]
__device__ __forceinline__ static float int2float_fast(int val) {
    float fval;
    // fval = (val & 0x7FFFFF) ^ 0x4B400000
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
                 : "=f"(fval)
                 : "r"(val), "n"(0x7FFFFF), "n"(0x4B400000), "n"((0xF0 & 0xCC) ^ 0xAA));
    return fval - 12582912.0f;
}

template<typename To, typename From>
__device__ __forceinline__ static To bit_cast(const From &input) {
    static_assert(sizeof(To) == sizeof(From));
    // not safe but anyway
    return *reinterpret_cast<const To *>(&input);
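Note: int2float_fast above builds the float bit pattern directly with one lop3 instead of the slower int-to-float conversion on older SMs. The host-side snippet below demonstrates the same bit trick (the lop3 replaced by its equivalent AND/XOR); it is an added illustration, not part of the commit.

// (val & 0x7FFFFF) ^ 0x4B400000, reinterpreted as float, equals 12582912.0f + val
// for val in [-4194304, 4194303]; 0x4B400000 is the bit pattern of 12582912.0f = 1.5 * 2^23.
#include <cstdint>
#include <cstdio>
#include <cstring>

static float int2float_fast_host(int val) {
    uint32_t bits = (uint32_t(val) & 0x7FFFFFu) ^ 0x4B400000u;
    float fval;
    std::memcpy(&fval, &bits, sizeof(fval)); // portable bit cast
    return fval - 12582912.0f;
}

int main() {
    const int samples[] = {-4194304, -1, 0, 1, 12345, 4194303};
    for (int v : samples)
        std::printf("%8d -> %10.1f\n", v, int2float_fast_host(v));
    return 0;
}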
@@ -395,20 +381,20 @@ static To bit_cast(const From &input) {
// both int2float and float2half are slow on sm_75 and before
// val in [-8192, 8191], steps of 16, round to negative inf
__device__ __forceinline__ static half2 int2half2_fast_8192(int x, int y) {
    uint32_t ival;
    uint32_t hval;
    // ival.lo = x.lo; ival.hi = y.lo;
    asm volatile("prmt.b32 %0, %1, %2, %3;" : "=r"(ival) : "r"(x), "r"(y), "n"(0x5410));
    ival = ival >> 4;
    // (val & 0x03FF03FF) ^ 0x76007600
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
                 : "=r"(hval)
                 : "r"(ival), "n"(0x03FF03FF), "n"(0x76007600), "n"((0xF0 & 0xCC) ^ 0xAA));
    return __hadd2(kernels::bit_cast<half2>(hval), half2(-24576.0f, -24576.0f));
}

// val in [-4096, 4095], steps of 8, round to nearest
__device__ __forceinline__ static half2 int2half2_fast_4096_rn(int x, int y) {
    // x = max(min(x, 4095), -4096);
    // y = max(min(y, 4095), -4096);
    // TODO: round to even?

@@ -416,24 +402,27 @@ static half2 int2half2_fast_4096_rn(int x, int y) {
    y = y * 8192 + 32768;
    uint32_t ival;
    uint32_t hval;
    // ival.lo = x.hi; ival.hi = y.hi;
    // <=> divide x and y by 65536 and pack them
    asm volatile("prmt.b32 %0, %1, %2, %3;" : "=r"(ival) : "r"(x), "r"(y), "n"(0x7632));
    // (val & 0x03FF03FF) ^ 0x72007200
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
                 : "=r"(hval)
                 : "r"(ival), "n"(0x03FF03FF), "n"(0x72007200), "n"((0xF0 & 0xCC) ^ 0xAA));
    return __hadd2(kernels::bit_cast<half2>(hval), half2(-12288.0f, -12288.0f));
}

// val in [-512, 511]
__device__ __forceinline__ static half2 int2half2_fast_512(int x, int y) {
    uint32_t ival;
    uint32_t hval;
    // ival.lo = x.lo; ival.hi = y.lo;
    // <=> divide x and y by 65536 and pack them
    asm volatile("prmt.b32 %0, %1, %2, %3;" : "=r"(ival) : "r"(x), "r"(y), "n"(0x5410));
    // (val & 0x03FF03FF) ^ 0x66006600
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
                 : "=r"(hval)
                 : "r"(ival), "n"(0x03FF03FF), "n"(0x66006600), "n"((0xF0 & 0xCC) ^ 0xAA));
    return __hadd2(kernels::bit_cast<half2>(hval), half2(-1536.0f, -1536.0f));
}

}; // namespace nunchaku::kernels
src/kernels/zgemm/gemm_w4a4.cu

@@ -14,7 +14,6 @@ struct FasterI2FMode {
    static bool check(bool act_unsigned);
};

template<typename F>
static void invoke_launch(Tensor::ScalarType dtype, bool use_fp4, bool fasterI2F, F &&launch) {
    if (fasterI2F && dtype == Tensor::FP16) {

@@ -32,37 +31,35 @@ static void invoke_launch(Tensor::ScalarType dtype, bool use_fp4, bool fasterI2F
    }
}

void gemm_w4a4(Tensor act,            // packed act [M, K / 2]
               Tensor wgt,            // packed act [N, K / 2]
               Tensor out,            // linear [M, N]
               Tensor qout,           // packed act [M, N / 2]
               Tensor ascales,        // packed as [K / 64, M]
               Tensor wscales,        // packed ws [K / 64, N]
               Tensor oscales,        // packed as [N / 64, M]
               Tensor poolout,        // linear [M / PoolSize, N]
               Tensor lora_act_in,    // packed lora_act [M, R]
               Tensor lora_up,        // packed lora_wgt [N, R]
               Tensor lora_down,      // packed lora_wgt [N, R]
               Tensor lora_act_out,   // packed lora_act [M, R]
               Tensor norm_q,         // linear [HEAD_DIM]
               Tensor norm_k,         // linear [HEAD_DIM]
               Tensor rotary_emb,     // linear [M, HEAD_DIM / 2, 2, 2]
               Tensor bias,           // packed ws [N]
               Tensor smooth_factor,  // packed ws [N], for quantization of the next layer
               Tensor out_vk,         // linear [B, num_heads, head_dim + 1, head_dim]
               Tensor out_linearattn, // linear [B, (M), N / 3]
               bool act_unsigned,
               std::vector<float> lora_scales, // [R / 16]
               bool fuse_silu,
               bool fp4,
               float alpha,
               Tensor wcscales,
               Tensor out_q,          // packed attention [B, H, M, D]
               Tensor out_k,          // packed attention [B, H, M, D]
               Tensor out_v,          // packed attention [B, H, M, D]
               int attn_tokens) {
    Tensor::ScalarType dtype = Tensor::INVALID_SCALAR_TYPE;
    if (!fp4) {
        dtype = ascales.dtype();

@@ -75,37 +72,35 @@ void gemm_w4a4(
        }
    }

    invoke_launch(dtype, fp4, FasterI2FMode::check(act_unsigned), [&]<typename Config, bool USE_FP4>() {
        GEMM_W4A4_Launch<Config, USE_FP4>::gemm_w4a4(act, wgt, out, qout, ascales, wscales, oscales, poolout,
                                                     lora_act_in, lora_up, lora_down, lora_act_out, norm_q, norm_k,
                                                     rotary_emb, bias, smooth_factor, out_vk, out_linearattn,
                                                     act_unsigned, lora_scales, fuse_silu, fp4, alpha, wcscales,
                                                     out_q, out_k, out_v, attn_tokens);
    });
}
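Note: invoke_launch resolves the runtime flags (dtype, use_fp4, fasterI2F) once and hands a generic lambda the compile-time <Config, USE_FP4> pair, so each kernel instantiation is selected in a single place. Below is a stripped-down sketch of that dispatch pattern using C++20 templated lambdas; ConfigFP16 / ConfigBF16 and the flag names are placeholders for illustration, not the types used in this file.

// Minimal sketch of runtime-to-compile-time dispatch via a templated lambda (C++20).
#include <cstdio>

struct ConfigFP16 { static constexpr const char *name = "fp16"; };
struct ConfigBF16 { static constexpr const char *name = "bf16"; };

template<typename F>
static void invoke_launch_sketch(bool is_bf16, bool use_fp4, F &&launch) {
    if (is_bf16) {
        use_fp4 ? launch.template operator()<ConfigBF16, true>()
                : launch.template operator()<ConfigBF16, false>();
    } else {
        use_fp4 ? launch.template operator()<ConfigFP16, true>()
                : launch.template operator()<ConfigFP16, false>();
    }
}

int main() {
    invoke_launch_sketch(false, true, []<typename Config, bool USE_FP4>() {
        std::printf("launch<%s, fp4=%d>\n", Config::name, (int)USE_FP4);
    });
    return 0;
}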
@@ -115,26 +110,28 @@ void linearattn_vk_mul_q(Tensor q, Tensor vk) {
    });
}

void quantize_w4a4_act_fuse_lora(Tensor input,
                                 Tensor output,
                                 Tensor oscales,
                                 Tensor lora_down,
                                 Tensor lora_act_out,
                                 Tensor smooth,
                                 bool fuse_glu,
                                 bool fp4) {
    invoke_launch(input.dtype(), fp4, false, [&]<typename Config, bool USE_FP4>() {
        GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act_fuse_lora(
            input, output, oscales, lora_down, lora_act_out, smooth, fuse_glu, fp4);
    });
}

void quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales) {
    invoke_launch(input.dtype(), false, false, [&]<typename Config, bool USE_FP4>() {
        GEMM_W4A4_Launch<Config, false>::quantize_w4a4_act(input, output, oscales);
    });
}

void quantize_w4a4_wgt(Tensor input, Tensor output, Tensor oscales) {
    invoke_launch(input.dtype(), false, false, [&]<typename Config, bool USE_FP4>() {
        GEMM_W4A4_Launch<Config, false>::quantize_w4a4_wgt(input, output, oscales);
    });
}

@@ -143,7 +140,7 @@ bool FasterI2FMode::check(bool act_unsigned) {
    if (prop->major != 7 || prop->minor != 5) {
        return false;
    }
    if (mode == Always) {
        return true;
    } else if (mode == Enabled && !act_unsigned) {

@@ -162,4 +159,4 @@ void set_faster_i2f_mode(std::string mode) {
    FasterI2FMode::mode = mapping.at(mode);
}

}; // namespace nunchaku::kernels
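Note: FasterI2FMode::check only ever enables the fast int-to-float path on sm_75 (Turing), as the prop->major/minor test above shows. The snippet below is a hedged host-side sketch of the corresponding capability query; it is illustrative and not taken from this file.

// Host-side sketch of an sm_75 gate using the standard CUDA runtime query.
#include <cstdio>
#include <cuda_runtime.h>

int main() {
    cudaDeviceProp prop{};
    if (cudaGetDeviceProperties(&prop, 0) != cudaSuccess)
        return 1;
    bool is_sm75 = (prop.major == 7 && prop.minor == 5);
    std::printf("sm_%d%d -> faster int-to-float path %s\n",
                prop.major, prop.minor, is_sm75 ? "eligible" : "not applicable");
    return 0;
}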
src/kernels/zgemm/gemm_w4a4.cuh

@@ -31,8 +31,7 @@ public:
    static constexpr bool FP4_AVAILABLE = false;
#endif

    __device__ __forceinline__ static void trap_no_fp4() {
        if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0) {
            printf("FP4 is not available on this device\n");
        }

@@ -44,12 +43,12 @@ public:
    static_assert(WARP_N % 32 == 0);
    static_assert(WARP_M % 32 == 0);
    static constexpr int WMSCALES_PACK_SIZE   = clamp(WARP_N / 32, 1, 4);
    static constexpr int WMSCALES_NUM_PACKS   = ceilDiv(WARP_N / 32, WMSCALES_PACK_SIZE);
    static constexpr int WMSCALES_VALID_LANES = WARP_SIZE;
    static constexpr int AMSCALES_PACK_SIZE   = clamp(WARP_M / 32, 1, 4);
    static constexpr int AMSCALES_NUM_PACKS   = ceilDiv(WARP_M / 32, AMSCALES_PACK_SIZE);
    static constexpr int AMSCALES_VALID_LANES = WARP_SIZE;

    struct packed_wmscale_t {

@@ -62,48 +61,50 @@ public:
    using wmscale_warp = std::array<packed_wmscale_t, WMSCALES_NUM_PACKS>;

    // amscales: [M / BLOCK_M, K / group size, NUM_WARPS, AMSCALES_NUM_PACKS, WARP_SIZE] of packed_amscale_t
    __device__ __forceinline__ static void
    load_amscale(const packed_amscale_t *ptr, int group, amscale_warp &out, bool pred) {
        const int laneId = threadIdx.x % WARP_SIZE;
        const int warpId = threadIdx.x / WARP_SIZE;
#pragma unroll
        for (int i = 0; i < AMSCALES_NUM_PACKS; i++) {
            out[i] = load_pred(&ptr[(group * NUM_WARPS + warpId) * AMSCALES_NUM_PACKS * AMSCALES_VALID_LANES +
                                    i * AMSCALES_VALID_LANES + laneId],
                               pred);
        }
    }

    // wmscales: [N / BLOCK_N, 1, K / group size, WMSCALES_NUM_PACKS, WMSCALES_VALID_LANES] of packed_wmscale_t
    __device__ __forceinline__ static void
    load_wmscale(const packed_wmscale_t *ptr, int group, wmscale_warp &out, bool pred) {
        const int laneId = threadIdx.x % WARP_SIZE;
#pragma unroll
        for (int i = 0; i < WMSCALES_NUM_PACKS; i++) {
            out[i] = load_pred(&ptr[(group * WMSCALES_NUM_PACKS + i) * WMSCALES_VALID_LANES + laneId], pred);
        }
    }

    __device__ __forceinline__ static void quantize_w4a4_fp4_from_fpsum_warp(
        const packed_fpsum_t (&fpsum)[INSN_K / INSN_N], packed_act_t &output, uint32_t &output_scale, int ida) {
        constexpr int NUM_GROUPS = 4;
        static_assert(NUM_GROUPS == INSN_K / INSN_N);

        constexpr float QVALUE_MAX       = 6.0f;
        constexpr float RECPI_QVALUE_MAX = 1 / QVALUE_MAX;
        constexpr float MSCALE_MAX       = 448.0f;

        const int laneId = threadIdx.x % WARP_SIZE;

        // 0 for row 0-7; 1 for row 8-15
        // each half2_t represents a 8*8 matrix
        half2_t input[2][INSN_K / INSN_N * 2];
#pragma unroll
        for (int i = 0; i < INSN_K / INSN_N; i++) {
            input[0][i * 2 + 0] = fpsum[i].data[0];
            input[0][i * 2 + 1] = fpsum[i].data[2];
            input[1][i * 2 + 0] = fpsum[i].data[1];
            input[1][i * 2 + 1] = fpsum[i].data[3];
        }
        auto maxabs = [](half2_t val) ALWAYSINLINE {
            val = __habs2(val);
            return __hmax(val.x, val.y);

@@ -111,14 +112,14 @@ public:
        // each half_t represents maxvalue in a 8*16 matrix
        half_t maxvalue[2][NUM_GROUPS];
#pragma unroll
        for (int i = 0; i < NUM_GROUPS; i++) {
            maxvalue[0][i] = __hmax(maxabs(input[0][i * 2]), maxabs(input[0][i * 2 + 1]));
            maxvalue[1][i] = __hmax(maxabs(input[1][i * 2]), maxabs(input[1][i * 2 + 1]));
        }
#pragma unroll
        for (int mask = 2; mask > 0; mask /= 2) {
#pragma unroll
            for (int i = 0; i < NUM_GROUPS; i++) {
                maxvalue[0][i] = __hmax(maxvalue[0][i], __shfl_xor_sync(~0, maxvalue[0][i], mask));
                maxvalue[1][i] = __hmax(maxvalue[1][i], __shfl_xor_sync(~0, maxvalue[1][i], mask));

@@ -128,10 +129,10 @@ public:
        float scale[2][NUM_GROUPS];
        float rscale[2][NUM_GROUPS];
#pragma unroll
        for (int i = 0; i < NUM_GROUPS; i++) {
            scale[0][i] = fminf(float(maxvalue[0][i]) * RECPI_QVALUE_MAX, MSCALE_MAX);
            scale[1][i] = fminf(float(maxvalue[1][i]) * RECPI_QVALUE_MAX, MSCALE_MAX);
            // TODO: check whether (1 / scale) or (1 / fp8scale) is better
            rscale[0][i] = cuda_frcp(scale[0][i]);
            rscale[1][i] = cuda_frcp(scale[1][i]);

@@ -152,30 +153,29 @@ public:
        if (laneId % 4 / 2 == ida) {
            output_scale = (laneId % 2 == 0) ? fp8scale[0] : fp8scale[1];
        }

        uint32_t qpacks[2][INSN_K / INSN_M * 2];
#pragma unroll
        for (int i = 0; i < INSN_K / INSN_M * 2; i++) {
#pragma unroll
            for (int j = 0; j < 2; j++) {
                float2 fval  = half22float2(input[j][i]) * make_float2(rscale[j][i / 2], rscale[j][i / 2]);
                qpacks[j][i] = quantize_float2_fp4(fval) << (laneId % 4 * 8);
            }
        }
#pragma unroll
        for (int mask = 1; mask <= 2; mask *= 2) {
#pragma unroll
            for (int i = 0; i < INSN_K / INSN_M * 2; i++) {
#pragma unroll
                for (int j = 0; j < 2; j++) {
                    qpacks[j][i] |= __shfl_xor_sync(~0, qpacks[j][i], mask);
                }
            }
        }
        // lane 0,1,2,3 / 4,5,6,7 / ... should have identical qpacks now
#pragma unroll
        for (int i = 0; i < 4; i++) {
            if (laneId % 4 == i) {
                output.x = qpacks[0][0 + i];
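Note: the quantizers in this file follow a group-wise absmax -> scale -> quantize flow (per-group max of |x|, a scale clamped to the target range, then each element multiplied by the reciprocal scale and saturated). The scalar C++ sketch below illustrates that flow for the signed 4-bit integer case; GROUP_SIZE, the function names, and the simple round/clamp are illustrative stand-ins (the FP4 path above instead uses cvt.rn.satfinite.e2m1x2 and an FP8 group scale).

// Scalar sketch of group-wise absmax quantization to a signed 4-bit range (illustrative).
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

constexpr int GROUP_SIZE   = 64;
constexpr float QVALUE_MAX = 7.0f; // symmetric use of the signed 4-bit range [-8, 7]

static void quantize_group(const float *x, int8_t *q, float &scale) {
    float maxabs = 0.0f;
    for (int i = 0; i < GROUP_SIZE; i++)
        maxabs = std::max(maxabs, std::fabs(x[i]));
    scale        = maxabs / QVALUE_MAX;                   // one scale per group
    float rscale = scale > 0.0f ? 1.0f / scale : 0.0f;    // multiply by reciprocal, as the kernels do
    for (int i = 0; i < GROUP_SIZE; i++) {
        int v = (int)std::lrintf(x[i] * rscale);          // round to nearest
        q[i]  = (int8_t)std::min(7, std::max(-8, v));     // saturate to 4 bits
    }
}

int main() {
    float x[GROUP_SIZE];
    for (int i = 0; i < GROUP_SIZE; i++)
        x[i] = 0.01f * (i - 32);
    int8_t q[GROUP_SIZE];
    float scale;
    quantize_group(x, q, scale);
    std::printf("scale=%f  q[0]=%d  q[63]=%d\n", scale, q[0], q[63]);
    return 0;
}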
@@ -188,88 +188,110 @@ public:
    // m16n16k64 MMA
    // ida, idb in {0, 1}
    __device__ __forceinline__ static packed_f32psum_t mma_fp4(
        packed_act_t act, packed_wgt_t wgt, packed_f32psum_t psum, uint32_t amscale, uint32_t wmscale, int ida, int idb) {
        packed_f32psum_t out;
        asm volatile("mma.sync.aligned.m16n8k64.row.col.kind::mxf4nvf4.block_scale.scale_vec::4X.f32.e2m1.e2m1.f32.ue4m3 "
                     "{%0, %1, %2, %3}, "
                     "{%4, %5, %6, %7}, "
                     "{%8, %9}, "
                     "{%10, %11, %12, %13}, "
                     "{%14}, {%15, %16}, "
                     "{%17}, {%18, %19};"
                     : "=f"(out.data[0]), "=f"(out.data[1]), "=f"(out.data[2]), "=f"(out.data[3])
                     : "r"(act.x), "r"(act.y), "r"(act.z), "r"(act.w), "r"(wgt.x), "r"(wgt.y),
                       "f"(psum.data[0]), "f"(psum.data[1]), "f"(psum.data[2]), "f"(psum.data[3]),
                       "r"(amscale), "n"(0), "h"((short)ida), "r"(wmscale), "n"(0), "h"((short)(idb * 2)));
        asm volatile("mma.sync.aligned.m16n8k64.row.col.kind::mxf4nvf4.block_scale.scale_vec::4X.f32.e2m1.e2m1.f32.ue4m3 "
                     "{%0, %1, %2, %3}, "
                     "{%4, %5, %6, %7}, "
                     "{%8, %9}, "
                     "{%10, %11, %12, %13}, "
                     "{%14}, {%15, %16}, "
                     "{%17}, {%18, %19};"
                     : "=f"(out.data[4]), "=f"(out.data[5]), "=f"(out.data[6]), "=f"(out.data[7])
                     : "r"(act.x), "r"(act.y), "r"(act.z), "r"(act.w), "r"(wgt.z), "r"(wgt.w),
                       "f"(psum.data[4]), "f"(psum.data[5]), "f"(psum.data[6]), "f"(psum.data[7]),
                       "r"(amscale), "n"(0), "h"((short)ida), "r"(wmscale), "n"(0), "h"((short)(idb * 2 + 1)));
        return out;
    }

    __device__ __forceinline__ static void
    compute_fp4(act_warp A, wgt_warp W, amscale_warp amscale, wmscale_warp wmscale, f32psum_warp &psum) {
        const int laneId = threadIdx.x % WARP_SIZE;
        const int warpId = threadIdx.x / WARP_SIZE;
#pragma unroll
        for (int j = 0; j < WARP_N_TILES; j++) {
#pragma unroll
            for (int i = 0; i < WARP_M_TILES; i++) {
                psum[i * WARP_N_TILES + j] =
                    mma_fp4(A[i], W[j], psum[i * WARP_N_TILES + j],
                            amscale[i / 2 / AMSCALES_PACK_SIZE].data[i / 2 % AMSCALES_PACK_SIZE],
                            wmscale[j / 2 / WMSCALES_PACK_SIZE].data[j / 2 % WMSCALES_PACK_SIZE], i % 2, j % 2);
            }
        }
    }

    template<typename Epilogue, bool USE_ALPHA>
    __device__ __forceinline__ static void gemm_w4a4_fp4_block(const BlockInfo binfo,
                                                               const packed_act_t *act,
                                                               const packed_wgt_t *wgt,
                                                               const packed_amscale_t *ascales,
                                                               const packed_wmscale_t *wscales,
                                                               float alpha, // per-tensor scale of weight
                                                               int M,
                                                               int N,
                                                               int K,
                                                               const Epilogue::Arguments &epilogueArgs,
                                                               bool alwaysfalse) {
        constexpr int NUM_STAGES = 2;

        const int laneId = threadIdx.x % WARP_SIZE;
        const int warpId = threadIdx.x / WARP_SIZE;

        act_warp A[NUM_STAGES];           // 8 * 2
        wgt_warp W[NUM_STAGES];           // 32 * 2
        amscale_warp amscale[NUM_STAGES]; // 1 * 2
        wmscale_warp wmscale[NUM_STAGES]; // 4 * 2
        f32psum_warp fpsum;               // 128

        for (int k = 0; k < NUM_STAGES - 1; k++) {
            load_act(act, k, K, A[k], true);

@@ -278,21 +300,21 @@ public:
            load_wmscale(wscales, k, wmscale[k], true);
        }

#pragma unroll
        for (auto &pack : fpsum) {
#pragma unroll
            for (int i = 0; i < 8; i++) {
                pack.data[i] = 0;
            }
        }

        int dummy = 0;

        for (int k1 = 0; k1 < K / WARP_K; k1 += NUM_STAGES) {
#pragma unroll
            for (int k2 = 0; k2 < NUM_STAGES; k2++) {
                int nextk = k1 + k2 + NUM_STAGES - 1;
                int idx   = (k2 + NUM_STAGES - 1) % NUM_STAGES;
                bool pred = nextk < K / WARP_K;
                load_act(act, nextk, K, A[idx], pred);
                load_wgt(wgt, nextk, K, W[idx], pred);

@@ -317,15 +339,14 @@ public:
        unused_var(dummy, alwaysfalse);

        if constexpr (USE_ALPHA) {
#pragma unroll
            for (auto &pack : fpsum) {
#pragma unroll
                for (int i = 0; i < 8; i++) {
                    pack.data[i] *= alpha;
                }
            }
        }

        auto f16psum = packed_fp32_to_fp16(fpsum);
@@ -337,21 +358,20 @@ public:
    template<typename Epilogue, bool USE_ALPHA>
    struct gemm_w4a4_fp4_kernel {
        static constexpr int MIN_ARCH = 1200;
        __device__ void operator()(const packed_act_t *act,
                                   const packed_wgt_t *wgt,
                                   const packed_amscale_t *ascales,
                                   const packed_wmscale_t *wscales,
                                   float alpha,
                                   int M,
                                   int N,
                                   int K,
                                   Epilogue::Arguments epilogueArgs,
                                   bool swapBlockXY,
                                   bool alwaysfalse) {
            BlockInfo binfo = {
                .bm         = (int)blockIdx.x,
                .bn         = (int)blockIdx.y,
                .numBlocksM = (int)gridDim.x,
                .numBlocksN = (int)gridDim.y,
            };

@@ -372,27 +392,27 @@ public:
                    ascales + bm * (K / WARP_K) * NUM_WARPS * AMSCALES_NUM_PACKS * AMSCALES_VALID_LANES,
                    wscales + bn * (K / WARP_K) * WMSCALES_NUM_PACKS * WMSCALES_VALID_LANES,
                    alpha,
                    M,
                    N,
                    K,
                    epilogueArgs,
                    alwaysfalse);
            } else {
                trap_no_fp4();
            }
        }
    };

public:
    template<bool ACT_UNSIGNED>
    __device__ __forceinline__ static packed_psum_t mma(packed_act_t act, packed_wgt_t wgt) {
        packed_psum_t psum;

        uint4 out1 = mma_m16n8kx_s32common<mma_helper::s4u4<ACT_UNSIGNED>, mma_helper::s4>(
            act, uint2(wgt.x, wgt.y), uint4(0, 0, 0, 0));
        uint4 out2 = mma_m16n8kx_s32common<mma_helper::s4u4<ACT_UNSIGNED>, mma_helper::s4>(
            act, uint2(wgt.z, wgt.w), uint4(0, 0, 0, 0));
        psum.data[0] = out1.x;
        psum.data[1] = out1.y;
        psum.data[2] = out1.z;

@@ -401,29 +421,30 @@ public:
        psum.data[5] = out2.y;
        psum.data[6] = out2.z;
        psum.data[7] = out2.w;
        return psum;
    }

    // template<bool si>
    template<bool use_unsigned>
    __device__ __forceinline__ static void quantize_w4a4_from_fpsum_warp(
        const packed_fpsum_t (&fpsum)[INSN_K / INSN_N], packed_act_t &output, half_t *output_scale) {
        const int laneId = threadIdx.x % WARP_SIZE;

        constexpr float QVALUE_MAX_SIGNED         = 7.0f;
        constexpr float QVALUE_MAX_UNSIGNED       = 15.0f;
        constexpr float RECPI_QVALUE_MAX_SIGNED   = 1 / QVALUE_MAX_SIGNED;
        constexpr float RECPI_QVALUE_MAX_UNSIGNED = 1 / QVALUE_MAX_UNSIGNED;
        constexpr float QVALUE_MAX       = use_unsigned ? QVALUE_MAX_UNSIGNED : QVALUE_MAX_SIGNED;
        constexpr float RECPI_QVALUE_MAX = use_unsigned ? RECPI_QVALUE_MAX_UNSIGNED : RECPI_QVALUE_MAX_SIGNED;
        // constexpr int QUANTIZE_BITMASK = 0xf;

        // 0 for row 0-7; 1 for row 8-15
        half2_t input[2][INSN_K / INSN_N * 2];
#pragma unroll
        for (int i = 0; i < INSN_K / INSN_N; i++) {
            input[0][i * 2 + 0] = fpsum[i].data[0];
            input[0][i * 2 + 1] = fpsum[i].data[2];

@@ -434,14 +455,14 @@ public:
        half_t maxvalue[2];
        maxvalue[0] = 0;
        maxvalue[1] = 0;
#pragma unroll
        for (int i = 0; i < INSN_K / INSN_M * 2; i++) {
            half2_t abs0 = __habs2(input[0][i]);
            half2_t abs1 = __habs2(input[1][i]);
            maxvalue[0]  = __hmax(maxvalue[0], __hmax(abs0.x, abs0.y));
            maxvalue[1]  = __hmax(maxvalue[1], __hmax(abs1.x, abs1.y));
        }
#pragma unroll
        for (int mask = 2; mask > 0; mask /= 2) {
            maxvalue[0] = __hmax(maxvalue[0], __shfl_xor_sync(~0, maxvalue[0], mask));
            maxvalue[1] = __hmax(maxvalue[1], __shfl_xor_sync(~0, maxvalue[1], mask));

@@ -455,7 +476,7 @@ public:
        scale[0] = float(maxvalue[0]) * RECPI_QVALUE_MAX;
        scale[1] = float(maxvalue[1]) * RECPI_QVALUE_MAX;
        if (laneId % 4 == 0) {
            output_scale[laneId / 4]     = half_t(scale[0]);
            output_scale[laneId / 4 + 8] = half_t(scale[1]);
        }

@@ -466,23 +487,23 @@ public:
        rscale[1] = cuda_frcp(scale[1]);

        uint32_t qpacks[2][INSN_K / INSN_M * 2];
#pragma unroll
        for (int i = 0; i < INSN_K / INSN_M * 2; i++) {
#pragma unroll
            for (int j = 0; j < 2; j++) {
                // half2_t hval = __hmul2(input[j][i], half2_t(rscale[j], rscale[j]));
                // float2 fval = half22float2(hval);
                float2 fval  = half22float2(input[j][i]) * make_float2(rscale[j], rscale[j]);
                qpacks[j][i] = quantize_float2<4, use_unsigned>(fval) << (laneId % 4 * 8);
            }
        }

        // 2 * 8 * 2 = 32 instructions => 256 cycles
#pragma unroll
        for (int mask = 1; mask <= 2; mask *= 2) {
#pragma unroll
            for (int i = 0; i < INSN_K / INSN_M * 2; i++) {
#pragma unroll
                for (int j = 0; j < 2; j++) {
                    qpacks[j][i] |= __shfl_xor_sync(~0, qpacks[j][i], mask);
                }

@@ -490,7 +511,7 @@ public:
        }
        // lane 0,1,2,3 / 4,5,6,7 / ... should have identical qpacks now
#pragma unroll
        for (int i = 0; i < 4; i++) {
            if (laneId % 4 == i) {
                output.x = qpacks[0][0 + i];

@@ -501,73 +522,74 @@ public:
        }
    }

    /**
     * each warp quantizes a INSN_M * INSN_K (16 * 64) matrix
     * input is per-warp (in global memory)
     * output is per-thread (in regs)
     * output_scale is per-warp (in shared memory)
     * shmem must be at least INSN_M * INSN_K * sizeof(element) (16 * 64 * 0.5 = 512 Bytes)
     * default to quantize activation, if quantize weight, input should be column-majored and output should be
     * transposed ({x, y, z, w} = {x, z, y, w})
     */
    __device__ __forceinline__ static void
    quantize_w4a4_warp(const half_t *input, int stride, packed_act_t &output, half_t *output_scale, void *shmem) {
        const int laneId = threadIdx.x % WARP_SIZE;

        constexpr int QUANTIZE_BITWIDTH = 4;
        constexpr int QVALUE_MAX        = 7; // 4 bit => [-8, 7]

        // 1 lane = 1 pack
        // 1 warp = 32 lanes = 32 packs = 1 packwarp
        // a pack is {a0, ..., a7} in figure
        // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=ex2#mma-16864-a PACK_SIZE * 4 =
        // INSN_K / 2
        constexpr int PACK_SIZE             = INSN_K / 8; // = 8 for 4bit
        constexpr int NUM_PACKS_PER_ROW     = INSN_K / PACK_SIZE;
        constexpr int NUM_ROWS_PER_PACKWARP = PACK_SIZE * WARP_SIZE / INSN_K;
        constexpr int NUM_PACKWARPS         = INSN_M / NUM_ROWS_PER_PACKWARP;
        using packed_input = std::array<half_t, PACK_SIZE>;

        packed_input packs[NUM_PACKWARPS];

        // load
#pragma unroll
        for (int i = 0; i < NUM_PACKWARPS; i++) {
            int rowId = i * NUM_ROWS_PER_PACKWARP + laneId / NUM_PACKS_PER_ROW;
            int colId = laneId % NUM_PACKS_PER_ROW * PACK_SIZE;
            packs[i]  = load(reinterpret_cast<const packed_input *>(input + rowId * stride + colId));
        }

        // find max
        half_t maxvalue[NUM_PACKWARPS];
#pragma unroll
        for (int i = 0; i < NUM_PACKWARPS; i++) {
            maxvalue[i] = __habs(packs[i][0]);
#pragma unroll
            for (int j = 1; j < PACK_SIZE; j++) {
                maxvalue[i] = __hmax(maxvalue[i], __habs(packs[i][j]));
            }
        }

        // warp reduce (max)
#pragma unroll
        for (int mask = NUM_PACKS_PER_ROW / 2; mask > 0; mask /= 2) {
#pragma unroll
            for (int i = 0; i < NUM_PACKWARPS; i++) {
                maxvalue[i] = __hmax(maxvalue[i], __shfl_xor_sync(~0, maxvalue[i], mask));
            }
        }

        // broadcast (max)
#pragma unroll
        for (int i = 0; i < NUM_PACKWARPS; i++) {
            maxvalue[i] = __shfl_sync(~0, maxvalue[i], laneId / NUM_PACKS_PER_ROW * NUM_PACKS_PER_ROW);
        }

        // quantize
        using matrix_t = uint32_t[INSN_M][NUM_PACKS_PER_ROW];
        matrix_t &mat  = *reinterpret_cast<matrix_t *>(shmem);
#pragma unroll
        for (int i = 0; i < NUM_PACKWARPS; i++) {
            half_t scale  = maxvalue[i] / half_t(QVALUE_MAX);
            half_t rscale = half_t(QVALUE_MAX) / maxvalue[i];
@@ -576,13 +598,13 @@ public:
}
uint32_t
qpack
=
0
;
// #pragma unroll
// for (int j = 0; j < PACK_SIZE; j++) {
// int intvalue = __half2int_rn(packs[i][j] / scale);
// intvalue = clamp(intvalue, -QVALUE_MAX, QVALUE_MAX);
// qpack |= (intvalue & QUANTIZE_BITMASK) << (QUANTIZE_BITWIDTH * j);
// }
#pragma unroll
// #pragma unroll
// for (int j = 0; j < PACK_SIZE; j++) {
// int intvalue = __half2int_rn(packs[i][j] / scale);
// intvalue = clamp(intvalue, -QVALUE_MAX, QVALUE_MAX);
// qpack |= (intvalue & QUANTIZE_BITMASK) << (QUANTIZE_BITWIDTH * j);
// }
#pragma unroll
for
(
int
j
=
0
;
j
<
PACK_SIZE
;
j
+=
2
)
{
half2_t
hval
=
__hmul2
(
half2_t
(
rscale
,
rscale
),
half2_t
(
packs
[
i
][
j
],
packs
[
i
][
j
+
1
]));
qpack
|=
quantize_float2
<
QUANTIZE_BITWIDTH
,
false
>
(
half22float2
(
hval
))
<<
(
j
*
QUANTIZE_BITWIDTH
);
...
...
@@ -590,7 +612,7 @@ public:
mat
[
i
*
NUM_ROWS_PER_PACKWARP
+
laneId
/
NUM_PACKS_PER_ROW
][
laneId
%
NUM_PACKS_PER_ROW
]
=
qpack
;
}
__syncwarp
();
// convert to imma format
int
row
=
laneId
%
16
;
int
col
=
laneId
/
16
*
4
;
...
...
@@ -602,12 +624,11 @@ public:
// each thread block (1 warp) quantize WARP_M * WARP_K tile (32 * 64)
struct
quantize_w4a4_act_kernel
{
static
constexpr
int
MIN_ARCH
=
std
::
is_same_v
<
half_t
,
__nv_bfloat16
>
?
800
:
750
;
__device__
void
operator
()(
const
half_t
*
input
,
packed_act_t
*
output
,
packed_ascale_t
*
oscales
,
int
K
)
{
__device__
void
operator
()(
const
half_t
*
input
,
packed_act_t
*
output
,
packed_ascale_t
*
oscales
,
int
K
)
{
const
int
laneId
=
threadIdx
.
x
%
WARP_SIZE
;
const
int
bm
=
blockIdx
.
x
/
(
BLOCK_M
/
WARP_M
);
const
int
bk
=
blockIdx
.
y
;
const
int
bm
=
blockIdx
.
x
/
(
BLOCK_M
/
WARP_M
);
const
int
bk
=
blockIdx
.
y
;
const
int
warpId
=
blockIdx
.
x
%
(
BLOCK_M
/
WARP_M
);
const
int
row
=
blockIdx
.
x
*
WARP_M
;
...
...
@@ -620,28 +641,27 @@ public:
packed_act_t
tmpout
;
quantize_w4a4_warp
(
input
+
(
row
+
tileId
*
INSN_M
)
*
K
+
col
,
K
,
tmpout
,
oscale_shmem
+
tileId
*
INSN_M
,
tmp_shmem
);
input
+
(
row
+
tileId
*
INSN_M
)
*
K
+
col
,
K
,
tmpout
,
oscale_shmem
+
tileId
*
INSN_M
,
tmp_shmem
);
store
(
&
output
[(((
bm
*
K
/
WARP_K
+
bk
)
*
NUM_WARPS
+
warpId
)
*
WARP_M_TILES
+
tileId
)
*
WARP_SIZE
+
laneId
],
tmpout
);
store
(
&
output
[(((
bm
*
K
/
WARP_K
+
bk
)
*
NUM_WARPS
+
warpId
)
*
WARP_M_TILES
+
tileId
)
*
WARP_SIZE
+
laneId
],
tmpout
);
}
// if (threadIdx.x == 0) {
// printf("Block (%d, %d) => offset = %d\n", blockIdx.x, blockIdx.y, (bm * K / WARP_K + bk) * NUM_WARPS + warpId);
// printf("Block (%d, %d) => offset = %d\n", blockIdx.x, blockIdx.y, (bm * K / WARP_K + bk) * NUM_WARPS
// + warpId);
// }
pack_ascales
(
oscale_shmem
,
&
oscales
[((
bm
*
K
/
WARP_K
+
bk
)
*
NUM_WARPS
+
warpId
)
*
ASCALES_NUM_PACKS
*
ASCALES_VALID_LANES
]);
pack_ascales
(
oscale_shmem
,
&
oscales
[((
bm
*
K
/
WARP_K
+
bk
)
*
NUM_WARPS
+
warpId
)
*
ASCALES_NUM_PACKS
*
ASCALES_VALID_LANES
]);
}
};
// each thread block (1 warp) quantize WARP_N * WARP_K tile (128 * 64)
struct
quantize_w4a4_wgt_kernel
{
static
constexpr
int
MIN_ARCH
=
std
::
is_same_v
<
half_t
,
__nv_bfloat16
>
?
800
:
750
;
__device__
void
operator
()(
const
half_t
*
input
,
packed_wgt_t
*
output
,
packed_wscale_t
*
oscales
,
int
K
)
{
__device__
void
operator
()(
const
half_t
*
input
,
packed_wgt_t
*
output
,
packed_wscale_t
*
oscales
,
int
K
)
{
const
int
laneId
=
threadIdx
.
x
%
WARP_SIZE
;
const
int
bn
=
blockIdx
.
x
/
(
BLOCK_N
/
WARP_N
);
...
...
@@ -657,12 +677,7 @@ public:
packed_wgt_t
tmpout
;
quantize_w4a4_warp
(
input
+
(
col
+
tileId
*
INSN_N
)
*
K
+
row
,
K
,
tmpout
,
oscale_shmem
+
tileId
*
INSN_N
,
tmp_shmem
);
input
+
(
col
+
tileId
*
INSN_N
)
*
K
+
row
,
K
,
tmpout
,
oscale_shmem
+
tileId
*
INSN_N
,
tmp_shmem
);
std
::
swap
(
tmpout
.
y
,
tmpout
.
z
);
...
...
@@ -674,59 +689,52 @@ public:
    };

    struct i2f_sm80 {
        __device__ __forceinline__ static float2 int2float2(int x, int y) {
            return make_float2(int2float_fast(x), int2float_fast(y));
        }
        __device__ __forceinline__ static half2_t int2half2(int x, int y) {
            return float22half2<half2_t>(int2float2(x, y));
        }
    };
    struct i2f_sm75 {
        __device__ __forceinline__ static float2 int2float2(int x, int y) {
            return make_float2(int2float_fast(x), int2float_fast(y));
        }
        __device__ __forceinline__ static half2_t int2half2(int x, int y) {
            return half2(__int2half_rn(x), __int2half_rn(y));
        }
    };
    struct i2f_sm75_fast {
        __device__ __forceinline__ static float2 int2float2(int x, int y) {
            return make_float2(int2float_fast(x), int2float_fast(y));
        }
        __device__ __forceinline__ static half2_t int2half2(int x, int y) {
            return int2half2_fast_512(x, y);
        }
    };
    template<bool ACT_UNSIGNED, typename T>
    __device__ __forceinline__ static void
    compute(act_warp A, wgt_warp W, ascale_warp ascale, wscale_warp wscale, T &fpsum) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 800
        using int2half2 = i2f_sm80;
#elif defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
        using int2half2 = std::conditional_t<Config::FASTER_I2F, i2f_sm75_fast, i2f_sm75>;
#else
        using int2half2 = Base::i2f_normal;
#endif
        Base::template apply_scales<int2half2>(
            [&](int i, int j) { return mma<ACT_UNSIGNED>(A[i], W[j]); }, ascale, wscale, fpsum);
    }
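    // For reference: a sketch of the "magic number" integer-to-float trick that fast i2f paths
    // typically rely on (an assumption for illustration; int2float_fast / int2half2_fast_512 may be
    // implemented differently). For |x| < 2^22, adding x to the bit pattern of 12582912.0f
    // (= 1.5 * 2^23) stays in a binade whose ulp is exactly 1.0, so subtracting the magic constant
    // recovers (float)x without an I2F instruction:
    __device__ __forceinline__ static float int2float_magic(int x) {
        const float magic = 12582912.0f; // bit pattern 0x4B400000
        float f           = __int_as_float(__float_as_int(magic) + x);
        return f - magic;
    }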
    __device__ __forceinline__ static void checkNan(fpsum_warp fpsum, const char *info = "") {
#if ENABLE_NAN_CHECK
        const int laneId = threadIdx.x % WARP_SIZE;
        const int warpId = threadIdx.x / WARP_SIZE;
...
...
@@ -735,14 +743,17 @@ public:
            for (int j = 0; j < 4; j++) {
                bool abnormal = !isfinite((float)fpsum[i].data[j].x) || !isfinite((float)fpsum[i].data[j].y);
                if (abnormal) {
                    printf("abnormal value detected at block.x=%d block.y=%d warpId=%d laneId=%d fpsum_warp (%s) i=%d "
                           "j=%d data.x=%f data.y=%f\n",
                           blockIdx.x,
                           blockIdx.y,
                           warpId,
                           laneId,
                           info,
                           i,
                           j,
                           (float)fpsum[i].data[j].x,
                           (float)fpsum[i].data[j].y);
                    __trap();
                }
            }
...
...
@@ -750,8 +761,7 @@ public:
#endif
    }

    __device__ __forceinline__ static void checkNan(packed_f32psum_t fpsum, const char *info = "") {
#if ENABLE_NAN_CHECK
        const int laneId = threadIdx.x % WARP_SIZE;
        const int warpId = threadIdx.x / WARP_SIZE;
...
...
@@ -759,21 +769,22 @@ public:
            for (int j = 0; j < 8; j++) {
                bool abnormal = !isfinite(fpsum.data[j]);
                if (abnormal) {
                    printf("abnormal value detected at bm=%d bn=%d warpId=%d laneId=%d packed_f32psum_t (%s) j=%d data=%f\n",
                           blockIdx.x,
                           blockIdx.y,
                           warpId,
                           laneId,
                           info,
                           j,
                           fpsum.data[j]);
                    __trap();
                }
            }
#endif
    }

    __device__ __forceinline__ static void checkNan(packed_fpsum_t fpsum, const char *info = "") {
#if ENABLE_NAN_CHECK
        const int laneId = threadIdx.x % WARP_SIZE;
        const int warpId = threadIdx.x / WARP_SIZE;
...
...
@@ -781,34 +792,36 @@ public:
            for (int j = 0; j < 4; j++) {
                bool abnormal = !isfinite((float)fpsum.data[j].x) || !isfinite((float)fpsum.data[j].y);
                if (abnormal) {
                    printf("abnormal value detected at bm=%d bn=%d warpId=%d laneId=%d packed_fpsum_t (%s) j=%d data.x=%f "
                           "data.y=%f\n",
                           blockIdx.x,
                           blockIdx.y,
                           warpId,
                           laneId,
                           info,
                           j,
                           (float)fpsum.data[j].x,
                           (float)fpsum.data[j].y);
                    __trap();
                }
            }
#endif
    }

    __device__ __forceinline__ static void checkNan(float data, const char *info = "") {
#if ENABLE_NAN_CHECK
        const int laneId = threadIdx.x % WARP_SIZE;
        const int warpId = threadIdx.x / WARP_SIZE;
        bool abnormal    = !isfinite(data);
        if (abnormal) {
            printf("abnormal value detected at bm=%d bn=%d warpId=%d laneId=%d packed_fpsum_t (%s) data=%f\n",
                   blockIdx.x,
                   blockIdx.y,
                   warpId,
                   laneId,
                   info,
                   data);
            __trap();
        }
#endif
...
...
@@ -816,19 +829,18 @@ public:
    // out: [M / BLOCK_M, N / BLOCK_N, NUM_WARPS, 1, NUM_M_TILES, NUM_N_TILES, WARP_SIZE] of fpsum_warp
    template<typename Epilogue, bool ACT_UNSIGNED, bool USE_FP32_ACCUM>
    __device__ __forceinline__ static void gemm_w4a4_block(const BlockInfo binfo,
                                                           const packed_act_t *act,
                                                           const packed_wgt_t *wgt,
                                                           const packed_ascale_t *ascales,
                                                           const packed_wscale_t *wscales,
                                                           // const packed_wscale_t *bias_ptr,
                                                           // half_t *out,
                                                           int M,
                                                           int N,
                                                           int K,
                                                           const Epilogue::Arguments &epilogueArgs,
                                                           bool alwaysfalse) {
        constexpr int NUM_STAGES = 2;

        const int laneId = threadIdx.x % WARP_SIZE;
...
...
@@ -838,11 +850,11 @@ public:
        fpsum_warp fpsum;
        GEMM_W4A4_Block<Config>()(act, wgt, ascales, wscales, K, fpsum, alwaysfalse);
#else
        act_warp A[NUM_STAGES];                                              // 8
        wgt_warp W[NUM_STAGES];                                              // 32
        ascale_warp ascale[NUM_STAGES];                                      // 1
        wscale_warp wscale[NUM_STAGES];                                      // 2
        std::conditional_t<USE_FP32_ACCUM, f32psum_warp, fpsum_warp> fpsum;  // 64

        // load_wscale<true>(wscales, wscale[0], true);
        // load_wscale<false>(wscales, wscale[1], true);
...
...
@@ -867,14 +879,14 @@ public:
            }
        }

        int dummy = 0;

        for (int k1 = 0; k1 < K / WARP_K; k1 += NUM_STAGES) {
#pragma unroll
            for (int k2 = 0; k2 < NUM_STAGES; k2++) {
                int nextk = k1 + k2 + NUM_STAGES - 1;
                int idx   = (k2 + NUM_STAGES - 1) % NUM_STAGES;
                bool pred = nextk < K / WARP_K;
                load_act(act, nextk, K, A[idx], pred);
                load_wgt(wgt, nextk, K, W[idx], pred);
...
...
@@ -889,11 +901,11 @@ public:
                compute<ACT_UNSIGNED>(A[k2], W[k2], ascale[k2], wscale[k2], fpsum);

                // #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
                if (alwaysfalse) {
                    dummy = clock();
                }
                // #endif

                // asm volatile ("membar.cta;");
            }
...
...
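    // The main loop above is a two-stage software pipeline: while stage k2 is being multiplied, the
    // tile NUM_STAGES - 1 iterations ahead is prefetched into the other register buffer (the
    // alwaysfalse/clock() dummy only keeps the compiler from sinking the loads). A stripped-down
    // sketch of the same double-buffering bookkeeping, with generic callables standing in for
    // load_act/load_wgt and compute (illustration only, not the kernel's code):
    template<typename Tile, typename LoadFn, typename ComputeFn, int NUM_STAGES = 2>
    __device__ static void pipelined_loop(int numTiles, LoadFn load_tile, ComputeFn compute_tile) {
        Tile buf[NUM_STAGES];
        load_tile(buf[0], 0, numTiles > 0); // prologue: fill stage 0
        for (int k1 = 0; k1 < numTiles; k1 += NUM_STAGES) {
#pragma unroll
            for (int k2 = 0; k2 < NUM_STAGES; k2++) {
                int nextk = k1 + k2 + NUM_STAGES - 1;           // tile to prefetch
                int idx   = (k2 + NUM_STAGES - 1) % NUM_STAGES; // buffer it lands in
                load_tile(buf[idx], nextk, nextk < numTiles);   // predicated prefetch
                compute_tile(buf[k2]);                          // consume the current stage
            }
        }
    }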
@@ -927,11 +939,17 @@ public:
            const packed_wscale_t *smooth_factor;
        };

        static constexpr int NUM_PACKS  = INSN_K / INSN_N;
        static constexpr int NUM_GROUPS = WARP_N_TILES / NUM_PACKS;

        __device__ __forceinline__ void apply_quantize(fpsum_warp fpsum,
                                                       int M,
                                                       int N,
                                                       int K,
                                                       packed_act_t *qout,
                                                       oscales_t *oscales,
                                                       half_t shift_value,
                                                       const packed_wscale_t *smooth_factor) {
            const int laneId = threadIdx.x % WARP_SIZE;
            const int warpId = threadIdx.x / WARP_SIZE;
...
...
@@ -940,21 +958,21 @@ public:
            wscale_warp smooth;
            load_wscale(smooth_factor, 0, N, smooth, true);

#pragma unroll
            for (int group = 0; group < NUM_GROUPS; group++) {
                amscale_warp omscale;

#pragma unroll
                for (int i = 0; i < WARP_M_TILES; i++) {
                    packed_fpsum_t tmp[NUM_PACKS];

#pragma unroll
                    for (int j = 0; j < NUM_PACKS; j++) {
                        half2_t ws1 = broadcast_wscale(smooth, (group * NUM_PACKS + j) * 4, laneId);
                        half2_t ws2 = broadcast_wscale(smooth, (group * NUM_PACKS + j) * 4 + 2, laneId);

#pragma unroll
                        for (int k = 0; k < 4; k++) {
                            half2_t src  = fpsum[i * WARP_N_TILES + group * NUM_PACKS + j].data[k];
                            half2_t &dst = tmp[j].data[k];
                            // dst.x = gelu(src.x);
...
...
@@ -977,7 +995,8 @@ public:
                    packed_act_t qresult;
                    if constexpr (USE_FP4) {
                        quantize_w4a4_fp4_from_fpsum_warp(
                            tmp, qresult, omscale[i / 2 / AMSCALES_PACK_SIZE].data[i / 2 % AMSCALES_PACK_SIZE], i % 2);
                    } else {
                        quantize_w4a4_from_fpsum_warp<USE_UNSIGNED>(tmp, qresult, &oscale_shmem[warpId][i * INSN_M]);
                    }
...
...
@@ -985,34 +1004,38 @@ public:
                }

                if constexpr (USE_FP4) {
#pragma unroll
                    for (int k = 0; k < AMSCALES_NUM_PACKS; k++) {
                        store(&oscales[((group * NUM_WARPS + warpId) * AMSCALES_NUM_PACKS + k) * AMSCALES_VALID_LANES +
                                       laneId],
                              omscale[k]);
                    }
                }
                if constexpr (!USE_FP4) {
                    __syncwarp();
                    pack_ascales(&oscale_shmem[warpId][0],
                                 &oscales[(group * NUM_WARPS + warpId) * ASCALES_NUM_PACKS * ASCALES_VALID_LANES]);
                    __syncwarp();
                }
            }
        }

        __device__ __forceinline__ void
        operator()(const BlockInfo binfo, fpsum_warp fpsum, int M, int N, int K, const Arguments &args) {
            const int bm = binfo.bm;
            const int bn = binfo.bn;

            if constexpr (!USE_FP4 || FP4_AVAILABLE) {
                apply_quantize(fpsum,
                               M,
                               N,
                               K,
                               args.qout + (bm * N / WARP_K + bn * NUM_GROUPS) * NUM_WARPS * WARP_M_TILES * WARP_SIZE,
                               args.oscales + (bm * N / WARP_K + bn * NUM_GROUPS) * NUM_WARPS *
                                                  (USE_FP4 ? AMSCALES_NUM_PACKS * AMSCALES_VALID_LANES
                                                           : ASCALES_NUM_PACKS * ASCALES_VALID_LANES),
                               args.shift_value,
                               args.smooth_factor + bn * WSCALES_NUM_PACKS * WSCALES_VALID_LANES);
            } else {
                trap_no_fp4();
            }
...
...
@@ -1025,22 +1048,21 @@ public:
        static constexpr int MIN_ARCH = std::is_same_v<half_t, __nv_bfloat16> ? 800 : 750;
        static constexpr int MAX_ARCH = Config::FASTER_I2F ? 750 : INT_MAX;  // FASTER_I2F is only needed on sm_75

        __device__ void operator()(const packed_act_t *act,
                                   const packed_wgt_t *wgt,
                                   const packed_ascale_t *ascales,
                                   const packed_wscale_t *wscales,
                                   int M,
                                   int N,
                                   int K,
                                   Epilogue::Arguments epilogueArgs,
                                   bool swapBlockXY,
                                   bool alwaysfalse) {
            // printf("Device sizeof(args) = %d", (int)sizeof(epilogueArgs));
            BlockInfo binfo = {
                .bm         = (int)blockIdx.x,
                .bn         = (int)blockIdx.y,
                .numBlocksM = (int)gridDim.x,
                .numBlocksN = (int)gridDim.y,
            };
...
...
@@ -1064,20 +1086,21 @@ public:
                // bias ? bias + bn * WSCALES_NUM_PACKS * WSCALES_VALID_LANES : nullptr,
                // out + (bm * BLOCK_M * N) + bn * BLOCK_N,
                // out + (bm * N / BLOCK_N + bn) * NUM_WARPS * WARP_M_TILES * WARP_N_TILES * WARP_SIZE,
                M,
                N,
                K,
                epilogueArgs,
                alwaysfalse);
        }
    };

    template<bool fuse_glu, bool use_fp4>
    struct quantize_w4a4_fuse_lora_kernel {
        using oscales_t = typename std::conditional_t<use_fp4, packed_amscale_t, packed_ascale_t>;

        static constexpr int MIN_ARCH = std::is_same_v<half_t, __nv_bfloat16> ? 800 : 750;

        static constexpr size_t SHMEM_PER_WARP =
            ceilDiv<size_t>(Base::template load_act_to_fpsum<fuse_glu>::SHMEM_SIZE, 128) * 128;
        static constexpr size_t SHMEM_SIZE = SHMEM_PER_WARP * NUM_WARPS;

        struct Arguments {
...
...
@@ -1091,25 +1114,23 @@ public:
            int lora_rank;

            // aligned to BLOCK_M and BLOCK_N
            int M, N;  // N should be the actual K in the next GEMM (needs /2 if fuse_glu)
            // the actual M and N (no need to /2 if fuse_glu)
            int actualM, actualN;

            bool alwaysfalse;
        };

        __device__ __forceinline__ void operator()(Arguments args) {
            const BlockInfo binfo = {
                .bm         = (int)blockIdx.x,
                .bn         = (int)blockIdx.y,
                .numBlocksM = (int)gridDim.x,
                .numBlocksN = (int)gridDim.y,
            };

            const int bm       = binfo.bm;
            const int bn       = binfo.bn;
            const int warpId   = threadIdx.x / WARP_SIZE;
            const int m_offset = bm * BLOCK_M + warpId * WARP_M;
...
...
@@ -1119,42 +1140,48 @@ public:
            fpsum_warp fpsum;

            Base::template load_act_to_fpsum<fuse_glu>()(args.input + m_offset * args.actualN + n_offset,
                                                         args.actualN,
                                                         args.actualM - m_offset,
                                                         args.actualN - n_offset,
                                                         fpsum,
                                                         shmem + warpId * SHMEM_PER_WARP
                                                         // args.smooth_factor ? args.smooth_factor + n_offset : nullptr
            );

            CHECK_NAN(fpsum, "fpsum");

            // for (int i = 0; i < 16; i++) {
            //     printf("bm=%d bn=%d warp=%d lane=%d fpsum[%d][0:1]=%f %f\n",
            //            bm, bn, warpId, threadIdx.x % WARP_SIZE, i,
            //            (float)fpsum[i].data[0].x, (float)fpsum[i].data[0].y);
            // }

            using EpilogueLoraDown = typename Lora<Config>::EpilogueLoraDown;
            EpilogueLoraDown()(binfo,
                               fpsum,
                               args.M,
                               args.N,
                               0,
                               typename EpilogueLoraDown::Arguments{
                                   .lora_wgt_down = args.lora_wgt_down,
                                   .lora_act      = args.lora_act,
                                   .rank          = args.lora_rank,
                                   .alwaysfalse   = args.alwaysfalse,
                               });
            EpilogueQuantize<false, false, use_fp4>()(binfo,
                                                      fpsum,
                                                      args.M,
                                                      args.N,
                                                      0,
                                                      typename EpilogueQuantize<false, false, use_fp4>::Arguments{
                                                          .qout          = args.output,
                                                          .oscales       = args.oscales,
                                                          .shift_value   = 0,
                                                          .smooth_factor = args.smooth_factor});
        }
    };
};

}; // namespace nunchaku::kernels
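// For reference: the math behind the EpilogueLoraDown used in the fused kernel above, in plain
// scalar form. Each token's output row x (length N) is projected onto the LoRA "down" weights to
// produce a rank-R activation that the later "up" epilogue expands again. Reference sketch only;
// the kernel performs this on packed fpsum tiles held in registers.
#include <vector>

std::vector<float> lora_down_reference(const std::vector<float> &x,      // [N]
                                       const std::vector<float> &w_down, // [N x R], row-major
                                       int N, int R) {
    std::vector<float> act(R, 0.0f);
    for (int n = 0; n < N; n++)
        for (int r = 0; r < R; r++)
            act[r] += x[n] * w_down[n * R + r];
    return act;
}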
src/kernels/zgemm/gemm_w4a4_launch.cuh View file @ 57e50f8d
...
...
@@ -5,57 +5,61 @@ namespace nunchaku::kernels {
template<typename Config, bool USE_FP4>
class GEMM_W4A4_Launch {
    using GEMM      = GEMM_W4A4<Config>;
    using Epilogues = Epilogues<Config>;
    using Lora      = Lora<Config>;

    using packed_act_t     = typename GEMM::packed_act_t;
    using packed_wgt_t     = typename GEMM::packed_wgt_t;
    using packed_ascale_t  = typename GEMM::packed_ascale_t;
    using packed_wscale_t  = typename GEMM::packed_wscale_t;
    using packed_amscale_t = typename GEMM::packed_amscale_t;
    using packed_wmscale_t = typename GEMM::packed_wmscale_t;
    using packed_fpsum_t   = typename GEMM::packed_fpsum_t;
    using half_t           = typename GEMM::half_t;

public:
    static void gemm_w4a4(Tensor act,            // packed act [M, K / 2]
                          Tensor wgt,            // packed act [N, K / 2]
                          Tensor out,            // linear     [M, N]
                          Tensor qout,           // packed act [M, N / 2]
                          Tensor ascales,        // packed as  [K / 64, M]
                          Tensor wscales,        // packed ws  [K / 64, N]
                          Tensor oscales,        // packed as  [N / 64, M]
                          Tensor poolout,        // linear     [M / PoolSize, N]
                          Tensor lora_act_in,    // packed lora_act [M, R]
                          Tensor lora_up,        // packed lora_wgt [N, R]
                          Tensor lora_down,      // packed lora_wgt [N, R]
                          Tensor lora_act_out,   // packed lora_act [M, R]
                          Tensor norm_q,         // linear     [HEAD_DIM]
                          Tensor norm_k,         // linear     [HEAD_DIM]
                          Tensor rotary_emb,     // linear     [M, HEAD_DIM / 2, 2, 2]
                          Tensor bias,           // packed ws  [N]
                          Tensor smooth_factor,  // packed ws  [N], for quantization of the next layer
                          Tensor out_vk,         // linear     [B, num_heads, head_dim + 1, head_dim]
                          Tensor out_linearattn, // linear     [B, (M), N / 3]
                          bool act_unsigned,
                          std::vector<float> lora_scales, // [R / 16]
                          bool fuse_silu,
                          bool fp4,
                          float alpha,
                          Tensor wcscales, // packed ws [N]
                          Tensor out_q,    // packed attention [B, H, M, D]
                          Tensor out_k,    // packed attention [B, H, M, D]
                          Tensor out_v,    // packed attention [B, H, M, D]
                          int attn_tokens);
    static void quantize_w4a4_act_fuse_lora(Tensor input,
                                            Tensor output,
                                            Tensor oscales,
                                            Tensor lora_down,
                                            Tensor lora_act_out,
                                            Tensor smooth,
                                            bool fuse_glu,
                                            bool fp4);
    static void quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales);
    static void quantize_w4a4_wgt(Tensor input, Tensor output, Tensor oscales);
    static void linearattn_vk_mul_q(Tensor q, Tensor vk);
};

}; // namespace nunchaku::kernels
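// Host-side usage sketch for the launcher declared above (assumes the Tensors are already
// allocated with the documented shapes: input [M, K] in fp16, packed [M, K / 2],
// oscales [K / 64, M]; this is an illustration, not code from the repository):
using namespace nunchaku::kernels;

void quantize_activation_example(Tensor input, Tensor packed, Tensor oscales) {
    GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, /*USE_FP4=*/false>::quantize_w4a4_act(input, packed, oscales);
}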
src/kernels/zgemm/gemm_w4a4_launch_bf16_fp4.cu View file @ 57e50f8d
#include "gemm_w4a4_launch_impl.cuh"
namespace nunchaku::kernels {
template class GEMM_W4A4_Launch<GEMMConfig_W4A4_BF16, true>;
}; // namespace nunchaku::kernels
src/kernels/zgemm/gemm_w4a4_launch_bf16_int4.cu View file @ 57e50f8d
#include "gemm_w4a4_launch_impl.cuh"
namespace nunchaku::kernels {
template class GEMM_W4A4_Launch<GEMMConfig_W4A4_BF16, false>;
}; // namespace nunchaku::kernels
src/kernels/zgemm/gemm_w4a4_launch_fp16_fp4.cu View file @ 57e50f8d
#include "gemm_w4a4_launch_impl.cuh"
namespace nunchaku::kernels {
template class GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, true>;
}; // namespace nunchaku::kernels
src/kernels/zgemm/gemm_w4a4_launch_fp16_int4.cu View file @ 57e50f8d
#include "gemm_w4a4_launch_impl.cuh"
namespace nunchaku::kernels {
template class GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>;
}; // namespace nunchaku::kernels
src/kernels/zgemm/gemm_w4a4_launch_fp16_int4_fasteri2f.cu View file @ 57e50f8d
#include "gemm_w4a4_launch_impl.cuh"
namespace nunchaku::kernels {
template class GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16_FasterI2F, false>;
}; // namespace nunchaku::kernels
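// The five .cu files above each explicitly instantiate one (Config, USE_FP4) combination of the
// launcher defined in gemm_w4a4_launch_impl.cuh, spreading the heavy template compilation across
// translation units. The general pattern, with hypothetical names (illustration only):
//
// widget.h: declare the template and promise an external instantiation.
template<typename T>
struct Widget {
    T process(T value) const { return value; }
};
extern template struct Widget<float>;
//
// widget_float.cpp: exactly one translation unit pays the instantiation cost.
template struct Widget<float>;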
src/kernels/zgemm/gemm_w4a4_launch_impl.cuh View file @ 57e50f8d
...
...
@@ -9,36 +9,35 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::gemm_w4a4(
template<>
void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
#endif
    Tensor act,            // packed act [M, K / 2]
    Tensor wgt,            // packed act [N, K / 2]
    Tensor out,            // linear     [M, N]
    Tensor qout,           // packed act [M, N / 2]
    Tensor ascales,        // packed as  [K / 64, M]
    Tensor wscales,        // packed ws  [K / 64, N]
    Tensor oscales,        // packed as  [N / 64, M]
    Tensor poolout,        // linear     [M / PoolSize, N]
    Tensor lora_act_in,    // packed lora_act [M, R]
    Tensor lora_up,        // packed lora_wgt [N, R]
    Tensor lora_down,      // packed lora_wgt [N, R]
    Tensor lora_act_out,   // packed lora_act [M, R]
    Tensor norm_q,         // linear     [HEAD_DIM]
    Tensor norm_k,         // linear     [HEAD_DIM]
    Tensor rotary_emb,     // linear     [M, HEAD_DIM / 2, 2, 2]
    Tensor bias,           // packed ws  [N]
    Tensor smooth_factor,  // packed ws  [N], for quantization of the next layer
    Tensor out_vk,         // linear     [B, num_heads, head_dim + 1, head_dim]
    Tensor out_linearattn, // linear     [B, (M), N / 3]
    bool act_unsigned,
    std::vector<float> lora_scales, // [R / 16]
    bool fuse_silu,
    bool fp4,
    float alpha,
    Tensor wcscales, // packed ws [N]
    Tensor out_q,    // packed attention [B, H, M, D]
    Tensor out_k,    // packed attention [B, H, M, D]
    Tensor out_v,    // packed attention [B, H, M, D]
    int attn_tokens) {
#ifdef __INTELLISENSE__
    static constexpr bool USE_FP4 = false;
#endif
...
...
@@ -89,32 +88,35 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
        if constexpr (!USE_FP4) {
            dispatchBool(act_unsigned, [&]<bool ACT_UNSIGNED>() {
                auto func = invoke_kernel<typename GEMM::gemm_w4a4_kernel<Epilogue, ACT_UNSIGNED>,
                                          const packed_act_t *,
                                          const packed_wgt_t *,
                                          const packed_ascale_t *,
                                          const packed_wscale_t *,
                                          int,
                                          int,
                                          int,
                                          typename Epilogue::Arguments,
                                          bool,
                                          bool>;
                if (shmem >= 24 * 1024) {
                    checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
                }
                assert(alpha == 1.0f);
                func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, shmem, getCurrentCUDAStream()>>>(
                    act.data_ptr<packed_act_t>(),
                    wgt.data_ptr<packed_wgt_t>(),
                    ascales.data_ptr<packed_ascale_t>(),
                    wscales.data_ptr<packed_wscale_t>(),
                    M,
                    N,
                    K,
                    args,
                    swapBlockMN,
                    false);
                checkCUDA(cudaGetLastError());
            });
            return;
...
...
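// dispatchBool (and the dispatchVal used later) turns a runtime flag into a compile-time template
// parameter so each branch gets a fully specialized kernel. A minimal stand-in with the assumed
// semantics (the project's own helper may differ in details), matching the C++20 templated-lambda
// call sites above:
template<typename F>
decltype(auto) dispatchBool(bool value, F &&f) {
    if (value) {
        return f.template operator()<true>();  // instantiate the <true> branch
    }
    return f.template operator()<false>();     // instantiate the <false> branch
}
// Usage mirrors the code above:
//   dispatchBool(act_unsigned, [&]<bool ACT_UNSIGNED>() { /* launch the specialized kernel */ });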
@@ -124,16 +126,18 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
        dispatchBool(alpha != 1.0f, [&]<bool USE_ALPHA>() {
            assert(!act_unsigned);
            auto func = invoke_kernel<typename GEMM::gemm_w4a4_fp4_kernel<Epilogue, USE_ALPHA>,
                                      const packed_act_t *,
                                      const packed_wgt_t *,
                                      const packed_amscale_t *,
                                      const packed_wmscale_t *,
                                      float,
                                      int,
                                      int,
                                      int,
                                      typename Epilogue::Arguments,
                                      bool,
                                      bool>;
            if (shmem >= 24 * 1024) {
                checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
...
...
@@ -141,21 +145,22 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
            assert(ascales.dtype() == Tensor::FP8_E4M3);
            assert(wscales.dtype() == Tensor::FP8_E4M3);
            func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, shmem, getCurrentCUDAStream()>>>(
                act.data_ptr<packed_act_t>(),
                wgt.data_ptr<packed_wgt_t>(),
                ascales.data_ptr<packed_amscale_t>(),
                wscales.data_ptr<packed_wmscale_t>(),
                alpha,
                M,
                N,
                K,
                args,
                swapBlockMN,
                false);
            checkCUDA(cudaGetLastError());
        });
        return;
    }
...
...
@@ -171,35 +176,37 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
        dispatchBool(bias.valid(), [&]<bool USE_BIAS>() {
            dispatchBool(wcscales.valid(), [&]<bool USE_SCALE>() {
                using EpilogueBias = typename GEMM::EpilogueBias<USE_BIAS, USE_SCALE>;
                // append EpilgoueNop to workaround mismatched memory layout of std::tuple between device and host code
                // on Windows
                // ** sizeof(std::tuple<std::tuple<int>>) == 8 on device **
                using Epilogue =
                    typename GEMM::EpilogueCombination<EpilogueBias, NextEpilogue, typename GEMM::EpilogueNop>;
                return launch.template operator()<Epilogue>(
                    {typename EpilogueBias::Arguments{
                         .bias  = USE_BIAS ? bias.data_ptr<packed_wscale_t>() : nullptr,
                         .scale = USE_SCALE ? wcscales.data_ptr<packed_wscale_t>() : nullptr,
                     },
                     nextArgs,
                     {}});
            });
        });
    };
    // auto launch_bias = launch;

    auto launch_lora = [&]<typename NextEpilogue, typename MidEpilogue>(NextEpilogue::Arguments nextArgs,
                                                                        MidEpilogue::Arguments midArgs) {
        assert(lora_up.valid() == lora_act_in.valid());
        assert(lora_down.valid() == lora_act_out.valid());

        const int rank_up   = lora_up.valid() ? lora_up.shape[1] : 0;
        const int rank_down = lora_down.valid() ? lora_down.shape[1] : 0;

        if (rank_up == 0) {
            assert(rank_down == 0);
            return launch_bias.template operator()<typename GEMM::EpilogueCombination<MidEpilogue, NextEpilogue>>(
                {midArgs, nextArgs});
        }

        assert(rank_up % 16 == 0);
        assert(lora_up.shape[0] == N);
...
...
@@ -207,7 +214,7 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
        assert(lora_act_in.shape[0] == M);
        assert(lora_act_in.shape[1] == rank_up);

        using LoraUp  = Lora;
        using scale_t = typename LoraUp::scale_t;
        scale_t scales;
...
...
@@ -218,19 +225,20 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
        }

        if (rank_down == 0) {
            using Epilogue = typename GEMM::EpilogueCombination<typename LoraUp::EpilogueLoraUp,
                                                                MidEpilogue,
                                                                NextEpilogue,
                                                                typename GEMM::EpilogueNop>;
            return launch_bias.template operator()<Epilogue>({typename LoraUp::EpilogueLoraUp::Arguments{
                                                                  .lora_act    = lora_act_in.data_ptr<float>(),
                                                                  .lora_wgt_up = lora_up.data_ptr<packed_fpsum_t>(),
                                                                  .rank        = rank_up,
                                                                  .scales      = scales,
                                                                  .alwaysfalse = false,
                                                              },
                                                              midArgs,
                                                              nextArgs,
                                                              {}});
        }
        // assert(rank_down == rank_up);
...
...
@@ -246,25 +254,27 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
        // dispatchVal(rank_down, std::integer_sequence<int, 16, 32, 48, 64, 80>(), [&]<int RANK_DOWN>() {
        using LoraDown = LoraUp;  // GEMM::Lora<RANK_DOWN>;
        using Epilogue = typename GEMM::EpilogueCombination<typename LoraUp::EpilogueLoraUp,
                                                            MidEpilogue,
                                                            typename LoraDown::EpilogueLoraDown,
                                                            NextEpilogue,
                                                            typename GEMM::EpilogueNop>;
        return launch_bias.template operator()<Epilogue>({typename LoraUp::EpilogueLoraUp::Arguments{
                                                              .lora_act    = lora_act_in.data_ptr<float>(),
                                                              .lora_wgt_up = lora_up.data_ptr<packed_fpsum_t>(),
                                                              .rank        = rank_up,
                                                              .scales      = scales,
                                                              .alwaysfalse = false,
                                                          },
                                                          midArgs,
                                                          typename LoraDown::EpilogueLoraDown::Arguments{
                                                              .lora_wgt_down = lora_down.data_ptr<packed_fpsum_t>(),
                                                              .lora_act      = lora_act_out.data_ptr<float>(),
                                                              .rank          = rank_down,
                                                              .alwaysfalse   = false,
                                                          },
                                                          nextArgs,
                                                          {}});
        // });
    };
...
...
@@ -276,29 +286,28 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
        static constexpr float SHIFT_GELU = 0.171875f;

        constexpr bool USE_UNSIGNED = !USE_FP4;
        using EpilogueQuantize      = typename GEMM::EpilogueQuantize<false, USE_UNSIGNED, USE_FP4>;
        auto argsQuantize           = typename EpilogueQuantize::Arguments{
            .qout          = qout.data_ptr<packed_act_t>(),
            .oscales       = oscales.data_ptr<typename EpilogueQuantize::oscales_t>(),
            .shift_value   = USE_FP4 ? 0.0f : SHIFT_GELU,
            .smooth_factor = smooth_factor.data_ptr<packed_wscale_t>()};

        // TODO: check if gelu is needed
        if (out.valid()) {
            launch_lora.template
            operator()<typename GEMM::EpilogueCombination<typename GEMM::EpilogueDefault, EpilogueQuantize>,
                       typename Epilogues::EpilogueGelu>({typename GEMM::EpilogueDefault::Arguments{
                                                              .out     = out.data_ptr<half_t>(),
                                                              .actualM = actualM,
                                                              .actualN = actualN,
                                                          },
                                                          argsQuantize},
                                                         {});
        } else {
            launch_lora.template operator()<EpilogueQuantize, typename Epilogues::EpilogueGelu>(argsQuantize, {});
        }

    } else if (out_linearattn.valid()) {
        assert(out_vk.valid());
...
...
@@ -311,7 +320,7 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
        assert(out_vk.shape[3] == Epilogue::LITELA_HEAD_DIM);
        assert(out_vk.shape[1] * Epilogue::LITELA_HEAD_DIM * 3 == N);
        int batch_size = out_vk.shape[0];
        int num_heads  = out_vk.shape[1];

        assert(isTypeMatch<half_t>(out_linearattn.dtype()));
        assert(out_linearattn.ndims() == 3);
...
...
@@ -326,12 +335,14 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
        out_vk.zero_();
        launch_lora.template operator()<Epilogue, typename GEMM::EpilogueNop>(
            typename Epilogue::Arguments{
                .out_q                = out_linearattn.data_ptr<half_t>(),
                .out_vk               = out_vk.data_ptr<float>(),
                .num_blocks_per_batch = num_blocks_per_batch,
                .actualM              = M,
            },
            {});

    } else if (rotary_emb.valid()) {
        assert(norm_q.valid());
...
...
@@ -342,8 +353,9 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
        assert(rotary_emb.shape[0] * rotary_emb.shape[1] == M);
        assert(rotary_emb.shape[2] == Epilogues::EpilogueRMSNormRope::HEAD_DIM);
        // assert(rotary_emb.numel() == M * GEMM::EpilogueQKVProj::HEAD_DIM / 2 *
        // GEMM::EpilogueQKVProj::ROTARY_EMB_NUM_ELEMENTS); launch_lora.template operator()<typename
        // GEMM::EpilogueQKVProj, typename GEMM::EpilogueNop>(typename GEMM::EpilogueQKVProj::Arguments{
        //     .out = out.data_ptr<half_t>(),
        //     .actualM = actualM,
        //     .actualN = actualN,
...
...
@@ -355,42 +367,48 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
        // }, {});

        using EpilogueRope = typename Epilogues::EpilogueRMSNormRope;
        auto argsRope      = typename Epilogues::EpilogueRMSNormRope::Arguments{
            .rotary_emb       = rotary_emb.data_ptr<typename EpilogueRope::packed_rotemb_t>(),
            .rmsnorm_weight_q = norm_q.data_ptr<half_t>(),
            .rmsnorm_weight_k = norm_k.data_ptr<half_t>(),
            .epsilon          = 1e-6,
        };

        if (out_q.valid()) {
            launch_lora.template
            operator()<typename GEMM::EpilogueCombination<EpilogueRope, typename Epilogues::EpiloguePackQKV>,
                       typename GEMM::EpilogueNop>(
                {argsRope,
                 typename Epilogues::EpiloguePackQKV::Arguments{
                     .out_q   = out_q.data_ptr<typename Epilogues::EpiloguePackQKV::packed_qkv_t>(),
                     .out_k   = out_k.data_ptr<typename Epilogues::EpiloguePackQKV::packed_qkv_t>(),
                     .out_v   = out_v.data_ptr<typename Epilogues::EpiloguePackQKV::packed_qkv_t>(),
                     .actualM = attn_tokens,
                     .strideHead_q =
                         int(out_q.stride(1) * out_q.scalar_size() / sizeof(typename Epilogues::EpiloguePackQKV::packed_qkv_t)),
                     .strideHead_k =
                         int(out_k.stride(1) * out_k.scalar_size() / sizeof(typename Epilogues::EpiloguePackQKV::packed_qkv_t)),
                     .strideHead_v =
                         int(out_v.stride(1) * out_v.scalar_size() / sizeof(typename Epilogues::EpiloguePackQKV::packed_qkv_t)),
                 }},
                {});
        } else {
            launch_lora.template
            operator()<typename GEMM::EpilogueCombination<EpilogueRope, typename GEMM::EpilogueDefault>,
                       typename GEMM::EpilogueNop>({argsRope,
                                                    typename GEMM::EpilogueDefault::Arguments{
                                                        .out     = out.data_ptr<half_t>(),
                                                        .actualM = actualM,
                                                        .actualN = actualN,
                                                    }},
                                                   {});
        }

    } else if (out.valid()) {
        using Epilogue = typename GEMM::EpilogueDefault;
        typename Epilogue::Arguments args{
            .out     = out.data_ptr<half_t>(),
            .actualM = actualM,
            .actualN = actualN,
        };
...
...
@@ -410,7 +428,7 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::linearattn_vk_mul_q(Tensor q, Tensor vk)
    using Epilogue = typename Epilogues::EpilogueLiteLA;

    int batch_size = vk.shape[0];
    int num_heads  = vk.shape[1];
    int num_tokens = q.shape[1];

    assert(isTypeMatch<half_t>(q.scalar_type()));
...
...
@@ -423,17 +441,21 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::linearattn_vk_mul_q(Tensor q, Tensor vk)
        BLOCK_SIZE = 128;
    }
    invoke_kernel<typename Epilogue::vk_mul_q_kernel>
        <<<dim3(ceilDiv(num_tokens, BLOCK_SIZE), num_heads, batch_size), BLOCK_SIZE, 0, getCurrentCUDAStream()>>>(
            q.data_ptr<half_t>(), vk.data_ptr<float>(), 1e-6f, num_tokens);
    checkCUDA(cudaGetLastError());
}

template<typename Config, bool USE_FP4>
void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act_fuse_lora(
    Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu, bool fp4) {
    const int actualM = input.numel() / input.shape[-1];
    const int actualN = input.shape[-1];
...
...
@@ -475,24 +497,24 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act_fuse_lora(Tensor input
        checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));

        // log(std::format("quantize_w4a4_act_fuse_lora M={} N={} input={} output={} (size={} numel={})", M, N,
        // input.data_ptr(), output.data_ptr(), output.buffer->getSize(), output.numel()));

        func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE, getCurrentCUDAStream()>>>(
            typename kernel::Arguments{
                .input         = input.data_ptr<half_t>(),
                .smooth_factor = smooth.valid() ? smooth.data_ptr<packed_wscale_t>() : nullptr,
                .output        = output.data_ptr<packed_act_t>(),
                .oscales       = oscales.data_ptr<typename kernel::oscales_t>(),
                .lora_wgt_down = lora_down.data_ptr<packed_fpsum_t>(),
                .lora_act      = lora_act_out.data_ptr<float>(),
                .lora_rank     = rank,
                .M             = M,
                .N             = N,
                .actualM       = actualM,
                .actualN       = actualN,
                .alwaysfalse   = false,
            });
        checkCUDA(cudaGetLastError());
    });
    // });
...
...
@@ -501,7 +523,7 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act_fuse_lora(Tensor input
template<typename Config, bool USE_FP4>
void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales) {
    if constexpr (USE_FP4) {
        assert(false);  // not implemented
        return;
    }
...
...
@@ -518,11 +540,7 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act(Tensor input, Tensor o
    dim3 grid(M / GEMM::WARP_M, K / GEMM::WARP_K);
    invoke_kernel<typename GEMM::quantize_w4a4_act_kernel><<<grid, GEMM::WARP_SIZE, 0, getCurrentCUDAStream()>>>(
        input.data_ptr<half_t>(), output.data_ptr<packed_act_t>(), oscales.data_ptr<packed_ascale_t>(), K);
    checkCUDA(cudaGetLastError());
}
...
...
@@ -540,19 +558,15 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_wgt(Tensor input, Tensor o
    assert(output.ndims() == 2);
    assert(output.shape[0] == N);
    assert(output.shape[1] == K / 2);

    assert(isTypeMatch<half_t>(oscales.dtype()));
    // assert(oscales.dtype() == Tensor::FP16);
    assert(oscales.numel() == N * K / GEMM::WARP_K);

    dim3 grid(N / GEMM::WARP_N, K / GEMM::WARP_K);
    invoke_kernel<typename GEMM::quantize_w4a4_wgt_kernel><<<grid, GEMM::WARP_SIZE, 0, getCurrentCUDAStream()>>>(
        input.data_ptr<half_t>(), output.data_ptr<packed_wgt_t>(), oscales.data_ptr<packed_wscale_t>(), K);
    checkCUDA(cudaGetLastError());
}

}; // namespace nunchaku::kernels
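// Host-side usage sketch for offline weight packing, matching the asserts above: wgt is an [N, K]
// half-precision weight, packed is [N, K / 2], wscales holds N * K / 64 group scales. Allocation
// of the output Tensors is assumed to happen elsewhere (illustration only):
void pack_weight_example(Tensor wgt, Tensor packed, Tensor wscales) {
    using Launch = nunchaku::kernels::GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, /*USE_FP4=*/false>;
    Launch::quantize_w4a4_wgt(wgt, packed, wscales);
}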
src/kernels/zgemm/gemm_w4a4_test.cu View file @ 57e50f8d
...
...
@@ -11,7 +11,7 @@ void test_rmsnorm_rope(Tensor input, Tensor output, Tensor norm_q, Tensor norm_k
    assert(input.shape.dataExtent == output.shape.dataExtent);
    assert(input.scalar_type() == Tensor::FP16);

    using GEMM     = Epilogues<GEMMConfig_W4A4_FP16>;
    using Epilogue = GEMM::EpilogueRMSNormRope;

    assert(M % GEMM::BLOCK_M == 0);
...
...
@@ -26,21 +26,18 @@ void test_rmsnorm_rope(Tensor input, Tensor output, Tensor norm_q, Tensor norm_k
    dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);
    func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE, getCurrentCUDAStream()>>>(
        typename kernel::Arguments{.input   = input.data_ptr<GEMM::half_t>(),
                                   .output  = output.data_ptr<GEMM::half_t>(),
                                   .M       = M,
                                   .N       = N,
                                   .actualM = M,
                                   .actualN = N,
                                   .argsEpilogue = typename Epilogue::Arguments{
                                       .rotary_emb       = rotary_emb.data_ptr<typename Epilogue::packed_rotemb_t>(),
                                       .rmsnorm_weight_q = norm_q.data_ptr<GEMM::half_t>(),
                                       .rmsnorm_weight_k = norm_k.data_ptr<GEMM::half_t>(),
                                       .epsilon          = 1e-6,
                                   }});
    checkCUDA(cudaGetLastError());
}
...
...
@@ -52,7 +49,7 @@ void test_pack_qkv(Tensor input, Tensor out_q, Tensor out_k, Tensor out_v, int n
    Tensor output = Tensor::empty_like(input);

    using GEMM     = Epilogues<GEMMConfig_W4A4_FP16>;
    using Epilogue = GEMM::EpiloguePackQKV;

    assert(M % GEMM::BLOCK_M == 0);
...
...
@@ -68,24 +65,25 @@ void test_pack_qkv(Tensor input, Tensor out_q, Tensor out_k, Tensor out_v, int n
    func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE, getCurrentCUDAStream()>>>(
        typename kernel::Arguments{
            .input   = input.data_ptr<GEMM::half_t>(),
            .output  = output.data_ptr<GEMM::half_t>(),
            .M       = M,
            .N       = N,
            .actualM = M,
            .actualN = N,
            .argsEpilogue =
                typename Epilogue::Arguments{
                    .out_q   = out_q.data_ptr<typename Epilogue::packed_qkv_t>(),
                    .out_k   = out_k.data_ptr<typename Epilogue::packed_qkv_t>(),
                    .out_v   = out_v.data_ptr<typename Epilogue::packed_qkv_t>(),
                    .actualM = numTokens,
                    .strideHead_q =
                        int(out_q.stride(1) * out_q.scalar_size() / sizeof(GEMM::EpiloguePackQKV::packed_qkv_t)),
                    .strideHead_k =
                        int(out_k.stride(1) * out_k.scalar_size() / sizeof(GEMM::EpiloguePackQKV::packed_qkv_t)),
                    .strideHead_v =
                        int(out_v.stride(1) * out_v.scalar_size() / sizeof(GEMM::EpiloguePackQKV::packed_qkv_t)),
                }});
    checkCUDA(cudaGetLastError());
}

}; // namespace nunchaku::kernels
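// The strideHead_* arguments above convert a per-head stride expressed in tensor elements into a
// stride counted in packed vectors: stride(1) is an element count, scalar_size() the element size
// in bytes, and dividing the byte distance by sizeof(packed_qkv_t) gives how many packed loads
// separate consecutive heads. In isolation the conversion is just:
#include <cstddef>
#include <cstdint>

int strideInPacked(std::int64_t elemStride, std::size_t scalarBytes, std::size_t packedBytes) {
    // mirrors int(t.stride(1) * t.scalar_size() / sizeof(packed_t)) from the calls above
    return static_cast<int>(elemStride * scalarBytes / packedBytes);
}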
src/kernels/zgemm/gemm_w8a8.cu View file @ 57e50f8d
...
...
@@ -17,24 +17,22 @@ void quantize_w8a8_act(Tensor input, Tensor output, Tensor oscales, bool fuse_gl
    assert(oscales.numel() == M * 1);

    auto launch = [&]<bool FUSE_GLU>() {
        using kernel = GEMM::quantize_w8a8_act_kernel<FUSE_GLU>;
        assert(kernel::check(M, K));

        dim3 grid  = kernel::gridSize(M, K);
        dim3 block = kernel::blockSize(M, K);

        auto func = invoke_kernel<kernel,
                                  const GEMM::half_t *,
                                  GEMM::packed_act_t *,
                                  GEMM::packed_ascale_t *,
                                  int,
                                  bool>;
        checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, 92160));
        func<<<grid, block, kernel::smemSize(M, K)>>>(input.data_ptr<GEMM::half_t>(),
                                                      output.data_ptr<GEMM::packed_act_t>(),
                                                      oscales.data_ptr<GEMM::packed_ascale_t>(),
                                                      K,
                                                      false);
        checkCUDA(cudaGetLastError());
    };
...
...
@@ -45,14 +43,12 @@ void quantize_w8a8_act(Tensor input, Tensor output, Tensor oscales, bool fuse_gl
    }
}

void gemm_w8a8(Tensor act,     // [M, K]
               Tensor wgt,     // [N, K]
               Tensor out,     // [M, N]
               Tensor ascales, // [1, M]
               Tensor wscales, // [1, N]
               Tensor bias) {
    using GEMM = GEMM_W8A8;

    int M = act.numel() / act.shape[-1];
...
...
@@ -78,16 +74,18 @@ void gemm_w8a8(Tensor act, // [M, K]
            std::swap(grid.x, grid.y);
        }
        invoke_kernel<GEMM::gemm_w8a8_kernel<Epilogue>>
            <<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS>>>(act.data_ptr<GEMM::packed_act_t>(),
                                                          wgt.data_ptr<GEMM::packed_wgt_t>(),
                                                          ascales.data_ptr<GEMM::packed_ascale_t>(),
                                                          wscales.data_ptr<GEMM::packed_wscale_t>(),
                                                          // out.valid() ? out.data_ptr<GEMM::half_t>() : nullptr,
                                                          M,
                                                          N,
                                                          K,
                                                          args,
                                                          swapBlockMN,
                                                          false);
        checkCUDA(cudaGetLastError());
    };
...
...
@@ -98,20 +96,19 @@ void gemm_w8a8(Tensor act, // [M, K]
        assert(bias.numel() == N);

        // append EpilgoueNop to workaround mismatched memory layout of std::tuple between device and host code on
        // Windows
        // ** sizeof(std::tuple<std::tuple<int>>) == 8 on device **
        using Epilogue = GEMM::EpilogueCombination<GEMM::EpilogueBias<true, false>, NextEpilogue, GEMM::EpilogueNop>;
        return launch.template operator()<Epilogue>({GEMM::EpilogueBias<true, false>::Arguments{
                                                         .bias = bias.data_ptr<GEMM::packed_wscale_t>(),
                                                     },
                                                     nextArgs,
                                                     {}});
    };

    launch_bias.template operator()<GEMM::EpilogueDefault>(GEMM::EpilogueDefault::Arguments{
        .out     = out.data_ptr<GEMM::half_t>(),
        .actualM = actualM,
        .actualN = actualN,
    });
...
...
@@ -152,9 +149,9 @@ void gemm_w8a8_fuse_litela(
    checkCUDA(cudaMemsetAsync(out_vk.data_ptr(), 0, out_vk.buffer->getSize()));

    auto func = invoke_kernel<GEMM::gemm_w8a8_kernel<Epilogue>,
                              const GEMM::packed_act_t *,
                              const GEMM::packed_wgt_t *,
                              const GEMM::packed_ascale_t *,
                              const GEMM::packed_wscale_t *,
                              // GEMM::half_t *,
...
...
@@ -178,7 +175,7 @@ void gemm_w8a8_fuse_litela(
        ascales.data_ptr<GEMM::packed_ascale_t>(),
        wscales.data_ptr<GEMM::packed_wscale_t>(),
        // nullptr,
        M, N, K, epilogueArgs,
        swapBlockMN,
        false
    );
...
...
@@ -193,4 +190,4 @@
    }
#endif

}; // namespace nunchaku::kernels
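// quantize_w8a8_act above produces one scale per row ([1, M]), i.e. per-token symmetric INT8
// quantization. The underlying math, independent of the packed layouts (reference sketch, not the
// kernel code):
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

void quantize_row_int8(const std::vector<float> &row, std::vector<std::int8_t> &q, float &scale) {
    float absmax = 0.0f;
    for (float v : row)
        absmax = std::max(absmax, std::fabs(v));
    scale = absmax > 0.0f ? absmax / 127.0f : 1.0f; // guard against an all-zero row
    q.resize(row.size());
    for (std::size_t i = 0; i < row.size(); i++) {
        int v = (int)std::lround(row[i] / scale);
        q[i]  = (std::int8_t)std::min(std::max(v, -127), 127);
    }
}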
src/kernels/zgemm/gemm_w8a8.cuh View file @ 57e50f8d
...
...
@@ -8,48 +8,52 @@ class GEMM_W8A8 : public GEMMBase<GEMMConfig_W8A8> {
public:
    using psum_warp = std::array<packed_psum_t, WARP_M_TILES * WARP_N_TILES>;

    __device__ __forceinline__ static packed_psum_t mma(packed_act_t act, packed_wgt_t wgt, packed_psum_t psum) {
        // packed_psum_t psum;
        asm volatile("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 "
                     "{%0, %1, %2, %3},"
                     "{%4, %5, %6, %7},"
                     "{%8, %9},"
                     "{%10, %11, %12, %13};\n"
                     : "=r"(psum.data[0]), "=r"(psum.data[1]), "=r"(psum.data[2]), "=r"(psum.data[3])
                     : "r"(act.x), "r"(act.y), "r"(act.z), "r"(act.w),
                       "r"(wgt.x), "r"(wgt.y),
                       // "r"(0), "r"(0), "r"(0), "r"(0)
                       "r"(psum.data[0]), "r"(psum.data[1]), "r"(psum.data[2]), "r"(psum.data[3]));
        asm volatile("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 "
                     "{%0, %1, %2, %3},"
                     "{%4, %5, %6, %7},"
                     "{%8, %9},"
                     "{%10, %11, %12, %13};\n"
                     : "=r"(psum.data[4]), "=r"(psum.data[5]), "=r"(psum.data[6]), "=r"(psum.data[7])
                     : "r"(act.x), "r"(act.y), "r"(act.z), "r"(act.w),
                       "r"(wgt.z), "r"(wgt.w),
                       // "r"(0), "r"(0), "r"(0), "r"(0)
                       "r"(psum.data[4]), "r"(psum.data[5]), "r"(psum.data[6]), "r"(psum.data[7]));
        return psum;
    }

    __device__ __forceinline__ static void compute(act_warp A, wgt_warp W, psum_warp &psum) {
        const int laneId = threadIdx.x % WARP_SIZE;
        const int warpId = threadIdx.x / WARP_SIZE;

#pragma unroll
        for (int j = 0; j < WARP_N_TILES; j++) {
#pragma unroll
            for (int i = 0; i < WARP_M_TILES; i++) {
                psum[i * WARP_N_TILES + j] = mma(A[i], W[j], psum[i * WARP_N_TILES + j]);
            }
...
...
@@ -62,11 +66,12 @@ public:
* oscales is per-warp (in shared memory)
* output is per-thread (in regs)
* shmem must be at least INSN_M * (INSN_K * sizeof(element) + 16) (16 * 32 = 512 Bytes)
* default to quantize activation, if quantize weight, input should be column-majored and output should be transposed ({x, y, z, w} = {x, z, y, w})
* default to quantize activation, if quantize weight, input should be column-majored and output should be
* transposed ({x, y, z, w} = {x, z, y, w})
*/
template
<
bool
input_shmem
=
false
>
__device__
__forceinline__
static
void
quantize_w8a8_warp
(
const
half_t
*
input
,
const
half_t
*
oscales
,
int
stride
,
packed_act_t
&
output
,
void
*
shmem
)
{
__device__
__forceinline__
static
void
quantize_w8a8_warp
(
const
half_t
*
input
,
const
half_t
*
oscales
,
int
stride
,
packed_act_t
&
output
,
void
*
shmem
)
{
const
int
laneId
=
threadIdx
.
x
%
WARP_SIZE
;
constexpr
int
QUANTIZE_BITWIDTH
=
8
;
...
...
@@ -75,28 +80,29 @@ public:
// 1 lane = 1 pack
// 1 warp = 32 lanes = 32 packs = 1 packwarp
// a pack is {a0, ..., a7} in figure https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=ex2#mma-16864-a
// PACK_SIZE * 4 = INSN_K / 2
constexpr
int
PACK_SIZE
=
INSN_K
/
8
;
// = 4 for 8bit
constexpr
int
NUM_PACKS_PER_ROW
=
INSN_K
/
PACK_SIZE
;
// a pack is {a0, ..., a7} in figure
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=ex2#mma-16864-a PACK_SIZE * 4 =
// INSN_K / 2
constexpr
int
PACK_SIZE
=
INSN_K
/
8
;
// = 4 for 8bit
constexpr
int
NUM_PACKS_PER_ROW
=
INSN_K
/
PACK_SIZE
;
constexpr
int
NUM_ROWS_PER_PACKWARP
=
PACK_SIZE
*
WARP_SIZE
/
INSN_K
;
constexpr
int
NUM_PACKWARPS
=
INSN_M
/
NUM_ROWS_PER_PACKWARP
;
using
packed_input
=
std
::
array
<
half_t
,
PACK_SIZE
>
;
constexpr
int
NUM_PACKWARPS
=
INSN_M
/
NUM_ROWS_PER_PACKWARP
;
using
packed_input
=
std
::
array
<
half_t
,
PACK_SIZE
>
;
packed_input
packs
[
NUM_PACKWARPS
];
// load
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_PACKWARPS
;
i
++
)
{
int
rowId
=
i
*
NUM_ROWS_PER_PACKWARP
+
laneId
/
NUM_PACKS_PER_ROW
;
int
colId
=
laneId
%
NUM_PACKS_PER_ROW
*
PACK_SIZE
;
packs
[
i
]
=
load
<
input_shmem
>
(
reinterpret_cast
<
const
packed_input
*>
(
input
+
rowId
*
stride
+
colId
));
packs
[
i
]
=
load
<
input_shmem
>
(
reinterpret_cast
<
const
packed_input
*>
(
input
+
rowId
*
stride
+
colId
));
}
// quantize
using
matrix_t
=
uint32_t
[
INSN_M
][
NUM_PACKS_PER_ROW
];
matrix_t
&
mat
=
*
reinterpret_cast
<
matrix_t
*>
(
shmem
);
#pragma unroll
matrix_t
&
mat
=
*
reinterpret_cast
<
matrix_t
*>
(
shmem
);
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_PACKWARPS
;
i
++
)
{
const
int
row
=
i
*
NUM_ROWS_PER_PACKWARP
+
laneId
/
NUM_PACKS_PER_ROW
;
const
int
col
=
laneId
%
NUM_PACKS_PER_ROW
;
...
...
@@ -104,7 +110,7 @@ public:
float
rscale
=
cuda_frcp
(
float
(
oscales
[
row
]));
uint32_t
qpack
=
0
;
#pragma unroll
#pragma unroll
for
(
int
j
=
0
;
j
<
PACK_SIZE
;
j
+=
2
)
{
// half2_t hval = __hmul2(half2_t(rscale, rscale), half2_t(packs[i][j], packs[i][j + 1]));
float2
fval
=
half22float2
(
half2_t
(
packs
[
i
][
j
],
packs
[
i
][
j
+
1
]))
*
float2
(
rscale
,
rscale
);
...
...
@@ -113,7 +119,7 @@ public:
mat
[
row
][
col
]
=
qpack
;
}
__syncwarp
();
// convert to imma format
int
row
=
laneId
%
16
;
int
col
=
laneId
/
16
*
4
;
...
...
@@ -126,20 +132,20 @@ public:
* each warp finds absmax from a row
*/
template
<
bool
fuse_glu
=
false
>
__device__
__forceinline__
static
half_t
findmax_warp
(
const
half_t
*
input
,
half_t
*
output_shmem
,
int
K
,
bool
alwaysfalse
)
{
__device__
__forceinline__
static
half_t
findmax_warp
(
const
half_t
*
input
,
half_t
*
output_shmem
,
int
K
,
bool
alwaysfalse
)
{
const
int
laneId
=
threadIdx
.
x
%
WARP_SIZE
;
using
packed_input
=
std
::
array
<
half2_t
,
4
>
;
using
packed_input
=
std
::
array
<
half2_t
,
4
>
;
using
packed_gated_input
=
std
::
array
<
half_t
,
4
>
;
constexpr
int
PACK_SIZE
=
sizeof
(
packed_input
)
/
sizeof
(
half_t
);
constexpr
int
PACK_SIZE
=
sizeof
(
packed_input
)
/
sizeof
(
half_t
);
constexpr
int
NUM_STAGES
=
2
;
half2_t
maxvalue2
=
{
0
,
0
};
half2_t
maxvalue2
=
{
0
,
0
};
packed_input
pack
[
NUM_STAGES
];
#pragma unroll
#pragma unroll
for
(
int
k
=
0
;
k
<
NUM_STAGES
-
1
;
k
++
)
{
const
int
idx
=
k
*
PACK_SIZE
*
WARP_SIZE
+
laneId
*
PACK_SIZE
;
if
(
idx
<
K
)
{
...
...
@@ -155,11 +161,11 @@ public:
// TODO: store quantized data to shmem (instead of half)
for
(
int
k1
=
0
;
k1
<
ceilDiv
(
K
,
PACK_SIZE
*
WARP_SIZE
);
k1
+=
NUM_STAGES
)
{
#pragma unroll
#pragma unroll
for
(
int
k2
=
0
;
k2
<
NUM_STAGES
;
k2
++
)
{
const
int
nextidx
=
(
k1
+
k2
+
NUM_STAGES
-
1
)
*
PACK_SIZE
*
WARP_SIZE
+
laneId
*
PACK_SIZE
;
const
int
nextk2
=
(
k2
+
NUM_STAGES
-
1
)
%
NUM_STAGES
;
const
int
nextk2
=
(
k2
+
NUM_STAGES
-
1
)
%
NUM_STAGES
;
if
(
nextidx
<
K
)
{
pack
[
nextk2
]
=
load
(
reinterpret_cast
<
const
packed_input
*>
(
&
input
[
nextidx
]));
...
...
@@ -172,11 +178,11 @@ public:
if
constexpr
(
fuse_glu
)
{
packed_gated_input
gated
;
#pragma unroll
#pragma unroll
for
(
int
j
=
0
;
j
<
p
.
size
();
j
++
)
{
gated
[
j
]
=
p
[
j
].
x
*
gelu_half
(
p
[
j
].
y
);
p
[
j
].
x
=
gated
[
j
];
p
[
j
].
y
=
0
;
p
[
j
].
x
=
gated
[
j
];
p
[
j
].
y
=
0
;
}
int
idx
=
(
k1
+
k2
)
*
PACK_SIZE
/
2
*
WARP_SIZE
+
laneId
*
PACK_SIZE
/
2
;
...
...
@@ -185,7 +191,7 @@ public:
}
}
#pragma unroll
#pragma unroll
for
(
int
j
=
0
;
j
<
p
.
size
();
j
++
)
{
maxvalue2
=
__hmax2
(
maxvalue2
,
__habs2
(
p
[
j
]));
}
...
...
@@ -194,7 +200,7 @@ public:
// unused_var(dummy, alwaysfalse);
#pragma unroll
#pragma unroll
for
(
int
mask
=
32
/
2
;
mask
>
0
;
mask
/=
2
)
{
maxvalue2
=
__hmax2
(
maxvalue2
,
__shfl_xor_sync
(
~
0
,
maxvalue2
,
mask
));
}
...
...
@@ -223,8 +229,8 @@ public:
return
INSN_M
*
K2
*
sizeof
(
half_t
);
}
__device__
void
operator
()(
const
half_t
*
input
,
packed_act_t
*
output
,
packed_ascale_t
*
oscales
,
int
K
,
bool
alwaysfalse
)
{
__device__
void
operator
()(
const
half_t
*
input
,
packed_act_t
*
output
,
packed_ascale_t
*
oscales
,
int
K
,
bool
alwaysfalse
)
{
// for quantize kernel
const
int
laneId
=
threadIdx
.
x
%
WARP_SIZE
;
const
int
warpId
=
threadIdx
.
x
/
WARP_SIZE
;
...
...
@@ -232,10 +238,9 @@ public:
const
int
numWarps
=
blockDim
.
x
/
WARP_SIZE
;
// for GEMM kernel
const
int
bm
=
blockIdx
.
x
/
(
BLOCK_M
/
WARP_M
);
const
int
bm
=
blockIdx
.
x
/
(
BLOCK_M
/
WARP_M
);
const
int
gemmWarpId
=
blockIdx
.
x
%
(
BLOCK_M
/
WARP_M
);
__shared__
alignas
(
128
)
half_t
oscale_shmem
[
WARP_M
];
// __shared__ alignas(128) half_t maxv_shmem[WARP_M];
__shared__
alignas
(
128
)
uint8_t
tmp_shmem
[
NUM_WARPS
][
512
];
...
...
@@ -249,7 +254,7 @@ public:
for
(
int
tileM
=
0
;
tileM
<
WARP_M_TILES
;
tileM
++
)
{
for
(
int
i
=
warpId
;
i
<
INSN_M
;
i
+=
numWarps
)
{
const
int
rowLocal
=
tileM
*
INSN_M
+
i
;
const
int
rowLocal
=
tileM
*
INSN_M
+
i
;
const
int
rowGlobal
=
blockIdx
.
x
*
WARP_M
+
rowLocal
;
half_t
maxv
=
findmax_warp
<
fuse_glu
>
(
input
+
rowGlobal
*
K
,
shmem
+
i
*
K2
,
K
,
alwaysfalse
);
...
...
@@ -260,76 +265,66 @@ public:
__syncthreads
();
for
(
int
bk
=
warpId
;
bk
<
K2
/
WARP_K
;
bk
+=
numWarps
)
{
const
int
rowLocal
=
tileM
*
INSN_M
;
const
int
rowLocal
=
tileM
*
INSN_M
;
const
int
rowGlobal
=
blockIdx
.
x
*
WARP_M
+
rowLocal
;
const
int
col
=
bk
*
WARP_K
;
const
int
col
=
bk
*
WARP_K
;
packed_act_t
tmpout
;
if
constexpr
(
fuse_glu
)
{
quantize_w8a8_warp
<
true
>
(
shmem
+
col
,
oscale_shmem
+
rowLocal
,
K2
,
tmpout
,
&
tmp_shmem
[
warpId
]
);
quantize_w8a8_warp
<
true
>
(
shmem
+
col
,
oscale_shmem
+
rowLocal
,
K2
,
tmpout
,
&
tmp_shmem
[
warpId
]);
}
else
{
quantize_w8a8_warp
<
false
>
(
input
+
rowGlobal
*
K
+
col
,
oscale_shmem
+
rowLocal
,
K
,
tmpout
,
&
tmp_shmem
[
warpId
]
);
input
+
rowGlobal
*
K
+
col
,
oscale_shmem
+
rowLocal
,
K
,
tmpout
,
&
tmp_shmem
[
warpId
]);
}
store
(
&
output
[(((
bm
*
K2
/
WARP_K
+
bk
)
*
NUM_WARPS
+
gemmWarpId
)
*
WARP_M_TILES
+
tileM
)
*
WARP_SIZE
+
laneId
],
tmpout
);
store
(
&
output
[(((
bm
*
K2
/
WARP_K
+
bk
)
*
NUM_WARPS
+
gemmWarpId
)
*
WARP_M_TILES
+
tileM
)
*
WARP_SIZE
+
laneId
],
tmpout
);
}
__syncthreads
();
}
// [M / BLOCK_M, 1, NUM_WARPS, ASCALES_NUM_PACKS, ASCALES_VALID_LANES] of packed_ascale_t
pack_ascales
(
oscale_shmem
,
&
oscales
[(
bm
*
NUM_WARPS
+
gemmWarpId
)
*
ASCALES_NUM_PACKS
*
ASCALES_VALID_LANES
]);
pack_ascales
(
oscale_shmem
,
&
oscales
[(
bm
*
NUM_WARPS
+
gemmWarpId
)
*
ASCALES_NUM_PACKS
*
ASCALES_VALID_LANES
]);
}
};
__device__
__forceinline__
static
gated_fpsum_warp
apply_glu
(
fpsum_warp
fpsum
)
{
__device__
__forceinline__
static
gated_fpsum_warp
apply_glu
(
fpsum_warp
fpsum
)
{
gated_fpsum_warp
result
;
for
(
int
i
=
0
;
i
<
WARP_M_TILES
;
i
++
)
{
for
(
int
j
=
0
;
j
<
WARP_N_TILES
;
j
++
)
{
for
(
int
k
=
0
;
k
<
4
;
k
++
)
{
half_t
&
dst
=
result
[
i
*
WARP_N_TILES
+
j
].
data
[
k
];
half2_t
src
=
fpsum
[
i
*
WARP_N_TILES
+
j
].
data
[
k
];
dst
=
src
.
x
*
gelu_half
(
src
.
y
);
dst
=
src
.
x
*
gelu_half
(
src
.
y
);
}
}
}
return
result
;
}
static
constexpr
int
unpack_gated_fpsum_shmem_size
=
INSN_M
*
(
WARP_N
/
2
+
8
)
*
sizeof
(
half_t
);
__device__
__forceinline__
static
void
unpack_gated_fpsum
(
gated_fpsum_warp
fpsum
,
half_t
*
output
,
int
stride
,
void
*
shmem
)
{
__device__
__forceinline__
static
void
unpack_gated_fpsum
(
gated_fpsum_warp
fpsum
,
half_t
*
output
,
int
stride
,
void
*
shmem
)
{
const
int
laneId
=
threadIdx
.
x
%
WARP_SIZE
;
constexpr
int
PACK_SIZE
=
WARP_N
/
2
/
WARP_SIZE
;
using
pack_t
=
std
::
array
<
half_t
,
PACK_SIZE
>
;
using
pack_t
=
std
::
array
<
half_t
,
PACK_SIZE
>
;
// +8 to prevent bank conflicts
using
matrix_t
=
half_t
[
INSN_M
][
WARP_N
/
2
+
8
];
matrix_t
&
mat
=
*
reinterpret_cast
<
matrix_t
*>
(
shmem
);
matrix_t
&
mat
=
*
reinterpret_cast
<
matrix_t
*>
(
shmem
);
for
(
int
i
=
0
;
i
<
WARP_M_TILES
;
i
++
)
{
for
(
int
j
=
0
;
j
<
WARP_N_TILES
;
j
++
)
{
packed_gated_fpsum_t
&
fsum
=
fpsum
[
i
*
WARP_N_TILES
+
j
];
int
row
=
laneId
/
4
;
int
col
=
laneId
%
4
+
j
*
INSN_N
/
2
;
*
reinterpret_cast
<
half_t
*>
(
&
mat
[
row
][
col
+
0
])
=
fsum
.
data
[
0
];
*
reinterpret_cast
<
half_t
*>
(
&
mat
[
row
][
col
+
4
])
=
fsum
.
data
[
2
];
packed_gated_fpsum_t
&
fsum
=
fpsum
[
i
*
WARP_N_TILES
+
j
];
int
row
=
laneId
/
4
;
int
col
=
laneId
%
4
+
j
*
INSN_N
/
2
;
*
reinterpret_cast
<
half_t
*>
(
&
mat
[
row
][
col
+
0
])
=
fsum
.
data
[
0
];
*
reinterpret_cast
<
half_t
*>
(
&
mat
[
row
][
col
+
4
])
=
fsum
.
data
[
2
];
*
reinterpret_cast
<
half_t
*>
(
&
mat
[
row
+
8
][
col
+
4
])
=
fsum
.
data
[
1
];
*
reinterpret_cast
<
half_t
*>
(
&
mat
[
row
+
8
][
col
+
4
])
=
fsum
.
data
[
3
];
}
...
...
@@ -345,28 +340,27 @@ public:
// out: [M, N] <=> [..., NUM_WARPS, WARP_M, N] of half
template
<
typename
Epilogue
>
__device__
__forceinline__
static
void
gemm_w8a8_block
(
const
BlockInfo
binfo
,
const
packed_act_t
*
act
,
const
packed_wgt_t
*
wgt
,
const
packed_ascale_t
*
ascales
,
const
packed_wscale_t
*
wscales
,
// half_t *out,
int
M
,
int
N
,
int
K
,
Epilogue
::
Arguments
epilogeParams
,
bool
alwaysfalse
)
{
__device__
__forceinline__
static
void
gemm_w8a8_block
(
const
BlockInfo
binfo
,
const
packed_act_t
*
act
,
const
packed_wgt_t
*
wgt
,
const
packed_ascale_t
*
ascales
,
const
packed_wscale_t
*
wscales
,
// half_t *out,
int
M
,
int
N
,
int
K
,
Epilogue
::
Arguments
epilogeParams
,
bool
alwaysfalse
)
{
constexpr
int
NUM_STAGES
=
2
;
const
int
laneId
=
threadIdx
.
x
%
WARP_SIZE
;
const
int
warpId
=
threadIdx
.
x
/
WARP_SIZE
;
act_warp
A
[
NUM_STAGES
];
// 8
wgt_warp
W
[
NUM_STAGES
];
// 32
ascale_warp
ascale
;
// 1
wscale_warp
wscale
;
// 2
psum_warp
psum
;
// 128
act_warp
A
[
NUM_STAGES
];
// 8
wgt_warp
W
[
NUM_STAGES
];
// 32
ascale_warp
ascale
;
// 1
wscale_warp
wscale
;
// 2
psum_warp
psum
;
// 128
for
(
auto
&
pack
:
psum
)
{
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
...
...
@@ -377,7 +371,7 @@ public:
// load_wscale<true>(wscales, wscale[0], true);
// load_wscale<false>(wscales, wscale[1], true);
// load_wscale<false>(wscales, wscale[2], true);
load_ascale
(
ascales
,
0
,
M
,
ascale
,
true
);
load_wscale
(
wscales
,
0
,
N
,
wscale
,
true
);
...
...
@@ -385,14 +379,14 @@ public:
load_act
(
act
,
k
,
K
,
A
[
k
],
true
);
load_wgt
(
wgt
,
k
,
K
,
W
[
k
],
true
);
}
int
dummy
=
0
;
for
(
int
k1
=
0
;
k1
<
K
/
WARP_K
;
k1
+=
NUM_STAGES
)
{
#pragma unroll
#pragma unroll
for
(
int
k2
=
0
;
k2
<
NUM_STAGES
;
k2
++
)
{
int
nextk
=
k1
+
k2
+
NUM_STAGES
-
1
;
int
idx
=
(
k2
+
NUM_STAGES
-
1
)
%
NUM_STAGES
;
int
idx
=
(
k2
+
NUM_STAGES
-
1
)
%
NUM_STAGES
;
bool
pred
=
nextk
<
K
/
WARP_K
;
load_act
(
act
,
nextk
,
K
,
A
[
idx
],
pred
);
load_wgt
(
wgt
,
nextk
,
K
,
W
[
idx
],
pred
);
...
...
@@ -421,17 +415,15 @@ public:
f32psum_warp
f32psum
;
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
f32psum
.
size
();
i
++
)
{
#pragma unroll
#pragma unroll
for
(
int
j
=
0
;
j
<
8
;
j
++
)
{
f32psum
[
i
].
data
[
j
]
=
0
;
}
}
apply_scales
([
&
](
int
i
,
int
j
)
{
return
psum
[
i
*
WARP_N_TILES
+
j
];
},
ascale
,
wscale
,
f32psum
);
apply_scales
([
&
](
int
i
,
int
j
)
{
return
psum
[
i
*
WARP_N_TILES
+
j
];
},
ascale
,
wscale
,
f32psum
);
fpsum_warp
fpsum
=
packed_fp32_to_fp16
(
f32psum
);
...
...
@@ -443,27 +435,24 @@ public:
Epilogue
()(
binfo
,
fpsum
,
M
,
N
,
K
,
epilogeParams
);
}
// out : [M / BLOCK_M, BLOCK_M, N / BLOCK_N, BLOCK_N]
template
<
typename
Epilogue
>
struct
gemm_w8a8_kernel
{
static
constexpr
int
MIN_ARCH
=
std
::
is_same_v
<
half_t
,
__nv_bfloat16
>
?
800
:
750
;
__device__
void
operator
()(
const
packed_act_t
*
act
,
const
packed_wgt_t
*
wgt
,
const
packed_ascale_t
*
ascales
,
const
packed_wscale_t
*
wscales
,
// half_t *out,
int
M
,
int
N
,
int
K
,
Epilogue
::
Arguments
epilogueArgs
,
bool
swapBlockXY
,
bool
alwaysfalse
)
{
__device__
void
operator
()(
const
packed_act_t
*
act
,
const
packed_wgt_t
*
wgt
,
const
packed_ascale_t
*
ascales
,
const
packed_wscale_t
*
wscales
,
// half_t *out,
int
M
,
int
N
,
int
K
,
Epilogue
::
Arguments
epilogueArgs
,
bool
swapBlockXY
,
bool
alwaysfalse
)
{
BlockInfo
binfo
=
{
.
bm
=
(
int
)
blockIdx
.
x
,
.
bn
=
(
int
)
blockIdx
.
y
,
.
bm
=
(
int
)
blockIdx
.
x
,
.
bn
=
(
int
)
blockIdx
.
y
,
.
numBlocksM
=
(
int
)
gridDim
.
x
,
.
numBlocksN
=
(
int
)
gridDim
.
y
,
};
...
...
@@ -476,25 +465,25 @@ public:
const
int
bm
=
binfo
.
bm
;
const
int
bn
=
binfo
.
bn
;
gemm_w8a8_block
<
Epilogue
>
(
binfo
,
act
+
bm
*
(
K
/
WARP_K
)
*
NUM_WARPS
*
WARP_M_TILES
*
WARP_SIZE
,
wgt
+
bn
*
(
K
/
WARP_K
)
*
WARP_N_TILES
*
WARP_SIZE
,
ascales
+
bm
*
(
1
)
*
NUM_WARPS
*
ASCALES_NUM_PACKS
*
ASCALES_VALID_LANES
,
// only 1 group in W8A8
wscales
+
bn
*
(
1
)
*
WSCALES_NUM_PACKS
*
WSCALES_VALID_LANES
,
// #if 1
// out + (bm * BLOCK_M * N) + bn * BLOCK_N,
// #else
// out + (bm * BLOCK_M * N / 2) + bn * BLOCK_N / 2,
// #endif
M
,
N
,
K
,
epilogueArgs
,
alwaysfalse
);
gemm_w8a8_block
<
Epilogue
>
(
binfo
,
act
+
bm
*
(
K
/
WARP_K
)
*
NUM_WARPS
*
WARP_M_TILES
*
WARP_SIZE
,
wgt
+
bn
*
(
K
/
WARP_K
)
*
WARP_N_TILES
*
WARP_SIZE
,
ascales
+
bm
*
(
1
)
*
NUM_WARPS
*
ASCALES_NUM_PACKS
*
ASCALES_VALID_LANES
,
// only 1 group in W8A8
wscales
+
bn
*
(
1
)
*
WSCALES_NUM_PACKS
*
WSCALES_VALID_LANES
,
// #if 1
// out + (bm * BLOCK_M * N) + bn * BLOCK_N,
// #else
// out + (bm * BLOCK_M * N / 2) + bn * BLOCK_N / 2,
// #endif
M
,
N
,
K
,
epilogueArgs
,
alwaysfalse
);
}
};
#if 0
struct EpilogueGLU {
struct Arguments { size_t unused; };
...
...
@@ -510,9 +499,6 @@ public:
}
};
#endif
};
};
// namespace nunchaku::kernels
\ No newline at end of file
};
// namespace nunchaku::kernels
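For orientation, the quantization path above (findmax_warp followed by quantize_w8a8_warp) amounts to per-row symmetric INT8 quantization: each warp finds a row's absolute maximum, a per-row scale is derived from it, and values are divided by that scale before being packed into the IMMA layout. The exact mapping from the row maximum to the stored scale lives in hunks not shown here; the common absmax/127 choice is used below purely for illustration, in a host-side sketch with assumed names.

#include <cmath>
#include <cstdint>
#include <vector>

// Host-side reference of per-row (per-token) symmetric INT8 quantization,
// mirroring what findmax_warp + quantize_w8a8_warp compute per warp on the GPU.
// Assumption: scale = absmax(row) / 127 (the kernel's scale derivation is elided
// from the diff above).
void quantize_rows_int8_ref(const std::vector<float> &x, int M, int K,
                            std::vector<int8_t> &q, std::vector<float> &scales) {
    q.resize(static_cast<size_t>(M) * K);
    scales.resize(M);
    for (int m = 0; m < M; m++) {
        float absmax = 0.f;
        for (int k = 0; k < K; k++)
            absmax = std::fmax(absmax, std::fabs(x[m * K + k]));
        const float scale = absmax > 0.f ? absmax / 127.f : 1.f;
        scales[m] = scale;
        for (int k = 0; k < K; k++)
            q[m * K + k] = static_cast<int8_t>(std::lrintf(x[m * K + k] / scale));
    }
}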
src/kernels/zgemm/lora.cuh
...
@@ -2,7 +2,6 @@
#include "gemm_base.cuh"

namespace nunchaku::kernels {

template<typename Config>
...
@@ -21,7 +20,7 @@ public:
public:
    static constexpr int MAX_RANK = 1024;
    static constexpr int WARP_R   = 16;

    // static constexpr int LORA_RANK = rank;
    static constexpr int LORA_M_TILES = WARP_M / 16;
...
@@ -30,57 +29,57 @@ public:
    static_assert(LORA_M_TILES == WARP_M_TILES);
    static_assert(LORA_N_TILES == WARP_N_TILES);

    // lora_down: [WARP_M, WARP_N] x [WARP_N, R] (row-wise) = [WARP_M, R]
    // lora up:   [WARP_M, R] x [WARP_N, R] (col-wise) = [WARP_M, WARP_N]
    // we use fp32 for lora activation since there's no bf16 reduction in sm_89 :(
    using lora_act_warp   = std::array<packed_f32psum_t, LORA_M_TILES * LORA_R_TILES>;
    using lora_act16_warp = std::array<packed_fpsum_t, LORA_M_TILES * LORA_R_TILES>;
    using lora_wgt_warp   = std::array<packed_fpsum_t, LORA_N_TILES * LORA_R_TILES>;

    using scale_t = std::array<float, MAX_RANK / 16>;

    // lora_wgt: [N / 16, rank / WARP_R, LORA_R_TILES, WARP_SIZE] of packed_fpsum_t
    //           [N / 16, rank / 16, WARP_SIZE]
    __device__ __forceinline__ static void
    load_lora_wgt(const packed_fpsum_t *ptr, int rtile, int rank, lora_wgt_warp &result, bool pred) {
        const int laneId = threadIdx.x % WARP_SIZE;

        const packed_fpsum_t *ptr_lane = &ptr[rtile * LORA_R_TILES * WARP_SIZE + laneId];
        const int stride_ntile         = rank / 16 * WARP_SIZE;

        unrolled_loop<LORA_N_TILES>([&]<int n>() {
            unrolled_loop<LORA_R_TILES>([&]<int r>() {
                constexpr int roffset = r * WARP_SIZE;
                const int noffset     = n * stride_ntile;
                result[n * LORA_R_TILES + r] = load_pred(ptr_lane + noffset + roffset, pred);
            });
        });
    }

    // lora_act: [M / BLOCK_M, rank / WARP_R, NUM_WARPS, LORA_M_TILES, LORA_R_TILES, 8, WARP_SIZE] of float
    __device__ __forceinline__ static void
    load_lora_act(const float *ptr, int rtile, lora_act_warp &result, bool pred) {
        const int laneId = threadIdx.x % WARP_SIZE;
        const int warpId = threadIdx.x / WARP_SIZE;

        const float *ptrlane =
            &ptr[(rtile * NUM_WARPS + warpId) * LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE + laneId];

        unrolled_loop<LORA_M_TILES>([&]<int m>() {
            unrolled_loop<LORA_R_TILES>([&]<int r> {
                constexpr int i = m * LORA_R_TILES + r;
                unrolled_loop<8>([&]<int j>() {
                    constexpr int offset = i * 8 * WARP_SIZE + j * WARP_SIZE;
                    result[i].data[j]    = load_pred(ptrlane + offset, pred); // * scales[rtile * LORA_R_TILES + r];
                });
                // CHECK_NAN(tmp, "load_lora_act.tmp");
            });
        });
    }

    // no vector reduction in sm_89 :(
    __device__ __forceinline__ static void reduce_lora_act(float *ptr, int rtile, lora_act_warp val, bool pred) {
        const int laneId = threadIdx.x % WARP_SIZE;
        const int warpId = threadIdx.x / WARP_SIZE;
...
@@ -108,7 +107,6 @@ public:
    //     });
    // }

    struct EpilogueLoraUp {
        struct Arguments {
            const float *lora_act;
...
@@ -120,19 +118,23 @@ public:
            bool alwaysfalse;
        };

        __device__ __forceinline__ static void apply_lora_up(fpsum_warp &fpsum,
                                                             const float *act,
                                                             const packed_fpsum_t *wgt,
                                                             const scale_t &scales,
                                                             int rank,
                                                             bool alwaysfalse) {
            constexpr int NUM_STAGES = 2;

            const int laneId = threadIdx.x % WARP_SIZE;
            const int warpId = threadIdx.x / WARP_SIZE;

            lora_act_warp lora_act[NUM_STAGES]; // 32
            lora_wgt_warp lora_wgt[NUM_STAGES]; // 64

            int dummy = 0;

#pragma unroll
            for (int k = 0; k < NUM_STAGES - 1; k++) {
                // we have rank > 0
                const bool pred = k == 0 ? true : k < rank / WARP_R;
...
@@ -140,14 +142,14 @@ public:
                load_lora_wgt(wgt, 0, rank, lora_wgt[k], pred);
            }

            f32psum_warp f32psum = packed_fp16_to_fp32(fpsum); // 128

            auto compute = [&scales](lora_act_warp A, lora_wgt_warp W, f32psum_warp &f32psum, int rtile) ALWAYSINLINE {
                lora_act16_warp A_fp16;
                for (int m = 0; m < LORA_M_TILES; m++) {
                    for (int r = 0; r < LORA_R_TILES; r++) {
                        packed_f32psum_t pack = A[m * LORA_R_TILES + r];
#pragma unroll
                        for (int j = 0; j < 8; j++) {
                            pack.data[j] *= scales[rtile * LORA_R_TILES + r];
                        }
...
@@ -159,28 +161,28 @@ public:
                        for (int r = 0; r < LORA_R_TILES; r++) {
                            CHECK_NAN(lora_act[m * LORA_R_TILES + r], "lora_act");
                            CHECK_NAN(lora_wgt[n * LORA_R_TILES + r], "lora_wgt");
                            f32psum[m * WARP_N_TILES + n] = mma_f16xf16_f32(
                                A_fp16[m * LORA_R_TILES + r], W[n * LORA_R_TILES + r], f32psum[m * WARP_N_TILES + n]);
                        }
                    }
                }
            };

            for (int k1 = 0; k1 < rank / WARP_R; k1 += NUM_STAGES) {
#pragma unroll
                for (int k2 = 0; k2 < NUM_STAGES; k2++) {
                    if (k1 + k2 >= rank / WARP_R) {
                        break;
                    }

                    int nextk = k1 + k2 + NUM_STAGES - 1;
                    int idx   = (k2 + NUM_STAGES - 1) % NUM_STAGES;
                    bool pred = nextk < rank / WARP_R;

                    if (alwaysfalse) {
                        act += kernels::bit_cast<int>(lora_act[k2][0].data[0]);
                    }
                    if (alwaysfalse) {
                        dummy = clock();
                    }
...
@@ -194,25 +196,24 @@ public:
            // NVCC does not know rank > 0 :(
            // it will generate a branch instruction to skip the initial load (packed_fp16_to_fp32)
            // the branch splits the basic blocks and prevents the overlap of memory access and computing
            // add fake dependency of loaded data so NVCC will not skip the load
#pragma unroll
            for (int k = 0; k < NUM_STAGES - 1; k++) {
#pragma unroll
                for (auto &&data : lora_act[k]) {
#pragma unroll
                    for (int i = 0; i < 8; i++) {
                        dummy ^= kernels::bit_cast<int>(data.data[i]);
                    }
                }
#pragma unroll
                for (auto &&data : lora_wgt[k]) {
#pragma unroll
                    for (int i = 0; i < 4; i++) {
                        dummy ^= kernels::bit_cast<int>(data.data[i]);
                    }
                }
            }
            unused_var(dummy, alwaysfalse);
...
@@ -220,21 +221,20 @@ public:
            fpsum = packed_fp32_to_fp16(f32psum);
        }

        __device__ __forceinline__ void
        operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, const Arguments &args) {
            const int bm = binfo.bm;
            const int bn = binfo.bn;

            CHECK_NAN(fpsum, "fpsum");

            apply_lora_up(fpsum,
                          args.lora_act + bm * (args.rank / WARP_R) * (NUM_WARPS * LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE),
                          args.lora_wgt_up + bn * (BLOCK_N / 16) * (args.rank / 16) * WARP_SIZE,
                          args.scales,
                          args.rank,
                          args.alwaysfalse);

            CHECK_NAN(fpsum, "fpsum");
        }
...
@@ -250,16 +250,16 @@ public:
            bool alwaysfalse;
        };

        __device__ __forceinline__ static void
        apply_lora_down(fpsum_warp &fpsum, float *act, const packed_fpsum_t *wgt, int rank, bool alwaysfalse) {
            constexpr int NUM_STAGES = 2;

            const int laneId = threadIdx.x % WARP_SIZE;
            const int warpId = threadIdx.x / WARP_SIZE;

            lora_wgt_warp lora_wgt[NUM_STAGES]; // 64

#pragma unroll
            for (int k = 0; k < NUM_STAGES - 1; k++) {
                // we have rank > 0
                bool pred = k == 0 ? true : k < rank / WARP_R;
...
@@ -270,11 +270,11 @@ public:
            lora_act_warp lora_act;
            lora_act.fill(packed_f32psum_t::zeros());

#pragma unroll
            for (int m = 0; m < LORA_M_TILES; m++) {
#pragma unroll
                for (int n = 0; n < LORA_N_TILES; n++) {
#pragma unroll
                    for (int r = 0; r < LORA_R_TILES; r++) {
                        auto &psum = lora_act[m * LORA_R_TILES + r];
...
@@ -294,14 +294,14 @@ public:
            int dummy = 0;

            for (int k1 = 0; k1 < rank / WARP_R; k1 += NUM_STAGES) {
#pragma unroll
                for (int k2 = 0; k2 < NUM_STAGES; k2++) {
                    if (k1 + k2 >= rank / WARP_R) {
                        break;
                    }

                    int nextk = k1 + k2 + NUM_STAGES - 1;
                    int idx   = (k2 + NUM_STAGES - 1) % NUM_STAGES;
                    bool pred = nextk < rank / WARP_R;

                    if (alwaysfalse) {
...
@@ -324,38 +324,33 @@ public:
                }
            }

#pragma unroll
            for (int k = 0; k < NUM_STAGES - 1; k++) {
#pragma unroll
                for (auto &&data : lora_wgt[k]) {
#pragma unroll
                    for (int i = 0; i < 4; i++) {
                        dummy ^= kernels::bit_cast<int>(data.data[i]);
                    }
                }
            }
            unused_var(dummy, alwaysfalse);
        }

        __device__ __forceinline__ void
        operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, const Arguments &args) {
            const int bm = binfo.bm;
            const int bn = binfo.bn;

            apply_lora_down(fpsum,
                            args.lora_act + bm * (args.rank / WARP_R) * (NUM_WARPS * LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE),
                            args.lora_wgt_down + bn * (BLOCK_N / 16) * (args.rank / 16) * WARP_SIZE,
                            args.rank,
                            args.alwaysfalse);
        }
    };
};

}; // namespace nunchaku::kernels
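As the comments at the top of the file state, the two epilogues implement the two halves of a low-rank (LoRA) update on a [WARP_M, WARP_N] tile: EpilogueLoraDown projects the tile onto rank R, and EpilogueLoraUp adds a previously computed rank-R activation back into the tile, scaled per group of 16 ranks. The plain reference below restates that math with illustrative names; in the kernels the two halves normally run as epilogues of different GEMMs, and the fragment layouts and pipelining are omitted.

#include <vector>

// Reference of the low-rank update the LoRA epilogues compute per output tile:
//   down: act[M, R]  = X[M, N] * Wdown[N, R]
//   up:   X[M, N]   += scale(r) * act[M, R] * Wup[N, R]^T
// scale(r) is shared by each group of 16 ranks, matching scale_t = float[MAX_RANK / 16].
void lora_update_ref(std::vector<float> &X, int M, int N, int R,
                     const std::vector<float> &Wdown,    // [N, R]
                     const std::vector<float> &Wup,      // [N, R]
                     const std::vector<float> &scales) { // [R / 16]
    std::vector<float> act(static_cast<size_t>(M) * R, 0.f);
    for (int m = 0; m < M; m++)
        for (int r = 0; r < R; r++)
            for (int n = 0; n < N; n++)
                act[m * R + r] += X[m * N + n] * Wdown[n * R + r];
    for (int m = 0; m < M; m++)
        for (int n = 0; n < N; n++)
            for (int r = 0; r < R; r++)
                X[m * N + n] += scales[r / 16] * act[m * R + r] * Wup[n * R + r];
}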
src/kernels/zgemm/mma.cuh
...
@@ -7,183 +7,169 @@
namespace nunchaku::kernels {

namespace mma_helper {
struct f32 {
    static constexpr const char value[] = "f32";
};
struct f16 {
    static constexpr const char value[] = "f16";
};
struct bf16 {
    static constexpr const char value[] = "bf16";
};
struct s32 {
    static constexpr const char value[] = "s32";
};
struct s4 {
    static constexpr const char value[] = "s4";
};
struct u4 {
    static constexpr const char value[] = "u4";
};
template<bool is_bf16>
using f16bf16 = std::conditional_t<is_bf16, bf16, f16>;
template<bool is_unsigned>
using s4u4 = std::conditional_t<is_unsigned, u4, s4>;
}; // namespace mma_helper

__device__ __forceinline__ static uint2 mma_m16n8k16_f16f16f16f16(uint4 a, uint2 b, uint2 c) {
    uint2 d;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
                 "{%0, %1}," "{%2, %3, %4, %5}," "{%6, %7}," "{%8, %9};\n"
                 : "=r"(d.x), "=r"(d.y)
                 : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y));
#else
    asm volatile("{"
                 ".reg .b32 tmp0, tmp1;"
                 "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
                 "{tmp0, tmp1}," "{%2, %3}," "{%6}," "{%8, %9};\n"
                 "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
                 "{%0, %1}," "{%4, %5}," "{%7}," "{tmp0, tmp1};"
                 "}\n"
                 : "=r"(d.x), "=r"(d.y)
                 : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y));
#endif
    return d;
}

template<bool is_bf16>
__device__ __forceinline__ static uint4 mma_m16n8k16_f32f16f16f32(uint4 a, uint2 b, uint4 c) {
    uint4 d;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.%14.%14.f32 "
                 "{%0, %1, %2, %3}," "{%4, %5, %6, %7}," "{%8, %9}," "{%10, %11, %12, %13};\n"
                 : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
                 : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y),
                   "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w),
                   "C"(mma_helper::f16bf16<is_bf16>::value));
#else
    static_assert(!is_bf16);
    asm volatile("{"
                 ".reg .b32 tmp0, tmp1, tmp2, tmp3;"
                 "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
                 "{tmp0, tmp1, tmp2, tmp3}," "{%4, %5}," "{%8}," "{%10, %11, %12, %13};\n"
                 "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
                 "{%0, %1, %2, %3}," "{%6, %7}," "{%9}," "{tmp0, tmp1, tmp2, tmp3};"
                 "}\n"
                 : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
                 : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y),
                   "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w));
#endif
    return d;
}

template<typename AType, typename BType>
__device__ __forceinline__ static uint4 mma_m16n8kx_s32common(uint4 a, uint2 b, uint4 c) {
    uint4 d;
    static constexpr int K =
        (std::is_same_v<AType, mma_helper::s4> || std::is_same_v<AType, mma_helper::u4>) ? 64 : 32;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    asm volatile("mma.sync.aligned.m16n8k%14.row.col.s32.%15.%16.s32 "
                 "{%0, %1, %2, %3}," "{%4, %5, %6, %7}," "{%8, %9}," "{%10, %11, %12, %13};\n"
                 : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
                 : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y),
                   "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w),
                   "n"(K), "C"(AType::value), "C"(BType::value));
#else
    asm volatile("{"
                 ".reg .b32 tmp0, tmp1, tmp2, tmp3;"
                 "mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
                 "{tmp0, tmp1}," "{%4}," "{%8}," "{%10, %11};\n"
                 "mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
                 "{tmp2, tmp3}," "{%5}," "{%8}," "{%12, %13};\n"
                 "mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
                 "{%0, %1}," "{%6}," "{%9}," "{tmp0, tmp1};\n"
                 "mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
                 "{%2, %3}," "{%7}," "{%9}," "{tmp2, tmp3};\n"
                 "}\n"
                 : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
                 : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y),
                   "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w),
                   "n"(K / 2), "C"(AType::value), "C"(BType::value));
#endif
    return d;
}

}; // namespace nunchaku::kernels
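For readers unfamiliar with the PTX shapes used above: mma.sync m16n8k16 with f16 inputs computes a 16×8 output tile D = A·B + C from a 16×16 A tile (row-major) and a 16×8 B tile (column-major, hence .row.col). The scalar reference below states that contract only; the per-thread fragment packing into uint4/uint2 registers is defined by PTX and is not reproduced here.

// Scalar reference of the m16n8k16 contract wrapped by the functions above:
// D[16][8] = A[16][16] * B[16][8] + C[16][8].
void mma_m16n8k16_ref(const float A[16][16], const float B[16][8],
                      const float C[16][8], float D[16][8]) {
    for (int m = 0; m < 16; m++) {
        for (int n = 0; n < 8; n++) {
            float acc = C[m][n];
            for (int k = 0; k < 16; k++)
                acc += A[m][k] * B[k][n];
            D[m][n] = acc;
        }
    }
}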
src/kernels/zgemm/mma_earlycuda.cuh
...
@@ -6,156 +6,118 @@
// cuda 12.4- does not support "C" constraint in inline assembly :(
// use explicit specialization for now

namespace nunchaku::kernels {

namespace mma_helper {
struct f32 {
    static constexpr const char value[] = "f32";
};
struct f16 {
    static constexpr const char value[] = "f16";
};
struct bf16 {
    static constexpr const char value[] = "bf16";
};
struct s32 {
    static constexpr const char value[] = "s32";
};
struct s4 {
    static constexpr const char value[] = "s4";
};
struct u4 {
    static constexpr const char value[] = "u4";
};
template<bool is_bf16>
using f16bf16 = std::conditional_t<is_bf16, bf16, f16>;
template<bool is_unsigned>
using s4u4 = std::conditional_t<is_unsigned, u4, s4>;
}; // namespace mma_helper

__device__ __forceinline__ static uint2 mma_m16n8k16_f16f16f16f16(uint4 a, uint2 b, uint2 c) {
    uint2 d;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
                 "{%0, %1}," "{%2, %3, %4, %5}," "{%6, %7}," "{%8, %9};\n"
                 : "=r"(d.x), "=r"(d.y)
                 : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y));
#else
    asm volatile("{"
                 ".reg .b32 tmp0, tmp1;"
                 "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
                 "{tmp0, tmp1}," "{%2, %3}," "{%6}," "{%8, %9};\n"
                 "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
                 "{%0, %1}," "{%4, %5}," "{%7}," "{tmp0, tmp1};"
                 "}\n"
                 : "=r"(d.x), "=r"(d.y)
                 : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y));
#endif
    return d;
}

template<bool is_bf16>
__device__ __forceinline__ static uint4 mma_m16n8k16_f32f16f16f32(uint4 a, uint2 b, uint4 c) = delete;

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
template<>
__device__ __forceinline__ uint4 mma_m16n8k16_f32f16f16f32<true>(uint4 a, uint2 b, uint4 c) {
    uint4 d;
    asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
                 "{%0, %1, %2, %3}," "{%4, %5, %6, %7}," "{%8, %9}," "{%10, %11, %12, %13};\n"
                 : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
                 : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y),
                   "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w));
    return d;
}
#endif

template<>
__device__ __forceinline__ uint4 mma_m16n8k16_f32f16f16f32<false>(uint4 a, uint2 b, uint4 c) {
    uint4 d;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
                 "{%0, %1, %2, %3}," "{%4, %5, %6, %7}," "{%8, %9}," "{%10, %11, %12, %13};\n"
                 : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
                 : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y),
                   "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w));
#else
    asm volatile("{"
                 ".reg .b32 tmp0, tmp1, tmp2, tmp3;"
                 "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
                 "{tmp0, tmp1, tmp2, tmp3}," "{%4, %5}," "{%8}," "{%10, %11, %12, %13};\n"
                 "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
                 "{%0, %1, %2, %3}," "{%6, %7}," "{%9}," "{tmp0, tmp1, tmp2, tmp3};"
                 "}\n"
                 : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
                 : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y),
                   "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w));
#endif
    return d;
}

template<typename AType, typename BType>
__device__ __forceinline__ static uint4 mma_m16n8kx_s32common(uint4 a, uint2 b, uint4 c) = delete;

template<>
__device__ __forceinline__ uint4 mma_m16n8kx_s32common<mma_helper::s4, mma_helper::s4>(uint4 a, uint2 b, uint4 c) {
    uint4 d;
    static constexpr int K = 64;
...
@@ -166,54 +128,50 @@ uint4 mma_m16n8kx_s32common<mma_helper::s4, mma_helper::s4>(uint4 a, uint2 b, ui
                 "{%4, %5, %6, %7}," "{%8, %9}," "{%10, %11, %12, %13};\n"
                 : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
                 : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y),
                   "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w), "n"(K));
#else
    asm volatile("{"
                 ".reg .b32 tmp0, tmp1, tmp2, tmp3;"
                 "mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
                 "{tmp0, tmp1}," "{%4}," "{%8}," "{%10, %11};\n"
                 "mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
                 "{tmp2, tmp3}," "{%5}," "{%8}," "{%12, %13};\n"
                 "mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
                 "{%0, %1}," "{%6}," "{%9}," "{tmp0, tmp1};\n"
                 "mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
                 "{%2, %3}," "{%7}," "{%9}," "{tmp2, tmp3};\n"
                 "}\n"
                 : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
                 : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y),
                   "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w), "n"(K / 2));
#endif
    return d;
}

template<>
__device__ __forceinline__ uint4 mma_m16n8kx_s32common<mma_helper::u4, mma_helper::s4>(uint4 a, uint2 b, uint4 c) {
    uint4 d;
    static constexpr int K = 64;
...
@@ -224,50 +182,46 @@ uint4 mma_m16n8kx_s32common<mma_helper::u4, mma_helper::s4>(uint4 a, uint2 b, ui
                 "{%4, %5, %6, %7}," "{%8, %9}," "{%10, %11, %12, %13};\n"
                 : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
                 : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y),
                   "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w), "n"(K));
#else
    asm volatile("{"
                 ".reg .b32 tmp0, tmp1, tmp2, tmp3;"
                 "mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
                 "{tmp0, tmp1}," "{%4}," "{%8}," "{%10, %11};\n"
                 "mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
                 "{tmp2, tmp3}," "{%5}," "{%8}," "{%12, %13};\n"
                 "mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
                 "{%0, %1}," "{%6}," "{%9}," "{tmp0, tmp1};\n"
                 "mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
                 "{%2, %3}," "{%7}," "{%9}," "{tmp2, tmp3};\n"
                 "}\n"
                 : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
                 : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y),
                   "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w), "n"(K / 2));
#endif
    return d;
}

}; // namespace nunchaku::kernels
src/kernels/zgemm/zgemm.h
...
@@ -5,50 +5,55 @@
namespace nunchaku::kernels {

void gemm_w4a4(Tensor act,            // packed act [M, K / 2]
               Tensor wgt,            // packed act [N, K / 2]
               Tensor out,            // linear     [M, N]
               Tensor qout,           // packed act [M, N / 2]
               Tensor ascales,        // packed as  [K / 64, M]
               Tensor wscales,        // packed ws  [K / 64, N]
               Tensor oscales,        // packed as  [N / 64, M]
               Tensor poolout,        // linear     [M / PoolSize, N]
               Tensor lora_act_in,    // packed lora_act [M, R]
               Tensor lora_up,        // packed lora_wgt [N, R]
               Tensor lora_down,      // packed lora_wgt [N, R]
               Tensor lora_act_out,   // packed lora_act [M, R]
               Tensor norm_q,         // linear     [HEAD_DIM]
               Tensor norm_k,         // linear     [HEAD_DIM]
               Tensor rotary_emb,     // linear     [M, HEAD_DIM / 2, 2, 2]
               Tensor bias,           // packed ws  [N]
               Tensor smooth_factor,  // packed ws  [N], for quantization of the next layer
               Tensor out_vk,         // linear     [B, num_heads, head_dim + 1, head_dim]
               Tensor out_linearattn, // linear     [B, (M), N / 3]
               bool act_unsigned,
               std::vector<float> lora_scales, // [R / 16]
               bool fuse_silu,
               bool fp4,
               float alpha,
               Tensor wcscales,
               Tensor out_q, // packed attention [B, H, M, D]
               Tensor out_k, // packed attention [B, H, M, D]
               Tensor out_v, // packed attention [B, H, M, D]
               int attn_tokens);

void linearattn_vk_mul_q(Tensor q, Tensor vk);

void quantize_w4a4_act_fuse_lora(Tensor input,
                                 Tensor output,
                                 Tensor oscales,
                                 Tensor lora_down,
                                 Tensor lora_act_out,
                                 Tensor smooth = {},
                                 bool fuse_glu = false,
                                 bool fp4      = false);

void quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales);
void quantize_w4a4_wgt(Tensor input, Tensor output, Tensor oscales);

void gemm_w8a8(Tensor act,     // [M, K]
               Tensor wgt,     // [N, K]
               Tensor out,     // [M, N]
               Tensor ascales, // [1, M]
               Tensor wscales, // [1, N]
               Tensor bias     // packed ws [N]
);

void quantize_w8a8_act(Tensor input, Tensor output, Tensor oscales, bool fuse_glu);
...
@@ -61,13 +66,11 @@ void quantize_w8a8_act(Tensor input, Tensor output, Tensor oscales, bool fuse_gl
//     Tensor wscales // [1, N]
// );

void attention_fp16(Tensor q, // packed [Batch, Head, TokensQ, HEAD_DIM]
                    Tensor k, // packed [Batch, Head, TokensKV, HEAD_DIM]
                    Tensor v, // packed [Batch, Head, TokensKV, HEAD_DIM]
                    Tensor o, // linear [Batch, TokensQ, Head * HEAD_DIM]
                    float scale);

// EXPERIMENTAL, for sm_75
void set_faster_i2f_mode(std::string mode);
...
@@ -76,4 +79,4 @@ void set_faster_i2f_mode(std::string mode);
void test_rmsnorm_rope(Tensor input, Tensor output, Tensor norm_q, Tensor norm_k, Tensor rotary_emb);
void test_pack_qkv(Tensor input, Tensor out_q, Tensor out_k, Tensor out_v, int numTokens);

}; // namespace nunchaku::kernels
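A minimal call sketch for the W8A8 path declared above. It only forwards to the declared entry point; how the packed tensors are allocated and quantized (e.g. via quantize_w8a8_act and the weight-packing utilities) is application-specific and not shown, and the include path depends on the build's include directories.

#include "zgemm.h" // assumed include path; adjust to the project's include setup

// Shapes follow the comments in the header above.
void run_w8a8(Tensor act,     // [M, K]  packed int8 activations
              Tensor wgt,     // [N, K]  packed int8 weights
              Tensor out,     // [M, N]  half/bf16 output
              Tensor ascales, // [1, M]  per-token activation scales
              Tensor wscales, // [1, N]  per-channel weight scales
              Tensor bias) {  // packed ws [N]
    nunchaku::kernels::gemm_w8a8(act, wgt, out, ascales, wscales, bias);
}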
src/layernorm.cpp

#include "layernorm.h"
#include "kernels/layernorm_kernels.h"

LayerNorm::LayerNorm(int hidden_size, float eps, bool elementwise_affine, Tensor::ScalarType dtype, Device device)
    : hidden_size(hidden_size), eps(eps) {
    if (elementwise_affine) {
        weight = Tensor::allocate({hidden_size}, dtype, device);
        bias   = Tensor::allocate({hidden_size}, dtype, device);
    }
    registerParams(weight, "weight")(bias, "bias");
}

Tensor LayerNorm::forward(Tensor x) {
...
@@ -27,10 +23,23 @@ Tensor RMSNorm::forward(Tensor x) {
    return out;
}

void RMSNormGeneral::forward_with_act_sum(Tensor x,
                                          Tensor quantized_hidden_states_buffer,
                                          Tensor quantized_scale_buffer,
                                          Tensor quantized_sum_buffer) {
    rms_norm_general_fuse_sum(quantized_hidden_states_buffer,
                              x,
                              this->weight,
                              quantized_sum_buffer,
                              quantized_scale_buffer,
                              variance_epsilon,
                              use_per_token_quant);
}

void RMSNormGeneral::forward_wo_act_sum(Tensor x,
                                        Tensor quantized_hidden_states_buffer,
                                        Tensor quantized_scale_buffer,
                                        Tensor quantized_sum_buffer) {
    rms_norm_general(quantized_hidden_states_buffer,
                     x,
                     this->weight,
                     quantized_scale_buffer,
                     variance_epsilon,
                     use_per_token_quant);
}
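RMSNormGeneral forwards to fused CUDA kernels; numerically, the underlying operation is standard RMSNorm followed, in the per-token case, by quantization of the normalized result. The scalar reference below restates that computation for illustration only (the fused kernels also handle the optional activation-sum output, data types, and layouts, which are omitted here), again assuming an absmax/127 INT8 scale.

#include <cmath>
#include <cstdint>
#include <vector>

// Scalar reference of RMSNorm + per-token INT8 quantization, roughly what
// rms_norm_general fuses on the GPU (illustrative sketch, not the kernel).
void rms_norm_quant_ref(const std::vector<float> &x, const std::vector<float> &weight,
                        int tokens, int hidden, float eps,
                        std::vector<int8_t> &q, std::vector<float> &scales) {
    q.resize(static_cast<size_t>(tokens) * hidden);
    scales.resize(tokens);
    std::vector<float> y(hidden);
    for (int t = 0; t < tokens; t++) {
        const float *row = &x[t * hidden];
        float sumsq = 0.f;
        for (int i = 0; i < hidden; i++)
            sumsq += row[i] * row[i];
        const float rrms = 1.f / std::sqrt(sumsq / hidden + eps);
        float absmax = 0.f;
        for (int i = 0; i < hidden; i++) {
            y[i]   = row[i] * rrms * weight[i];
            absmax = std::fmax(absmax, std::fabs(y[i]));
        }
        const float scale = absmax > 0.f ? absmax / 127.f : 1.f;
        scales[t] = scale;
        for (int i = 0; i < hidden; i++)
            q[t * hidden + i] = static_cast<int8_t>(std::lrintf(y[i] / scale));
    }
}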
src/layernorm.h
...
@@ -20,9 +20,8 @@ private:
class RMSNorm : public Module {
public:
    RMSNorm(int hidden_size, float eps, bool use_quant, Tensor::ScalarType dtype, Device device)
        : use_quant(use_quant), variance_epsilon(eps) {
        weight = Tensor::allocate({hidden_size}, dtype, device);
        registerParams(weight, "weight");
    }
...
@@ -36,13 +35,16 @@ public:
class RMSNormGeneral {
    friend class LlamaDecoderLayer;

public:
    RMSNormGeneral(int hidden_size, bool act_sum, float eps, bool use_per_token_quant, Device device)
        : act_sum(act_sum), use_per_token_quant(use_per_token_quant), variance_epsilon(eps) {
        this->weight = Tensor::ones({hidden_size}, Tensor::FP32, device);
    }

    void forward(Tensor x,
                 Tensor quantized_hidden_states_buffer,
                 Tensor quantized_scale_buffer,
                 Tensor quantized_sum_buffer) {
        if (act_sum) {
            forward_with_act_sum(x, quantized_hidden_states_buffer, quantized_scale_buffer, quantized_sum_buffer);
        } else {
...
@@ -51,12 +53,18 @@ public:
        }
    }

private:
    void forward_with_act_sum(Tensor x,
                              Tensor quantized_hidden_states_buffer,
                              Tensor quantized_scale_buffer,
                              Tensor quantized_sum_buffer);
    void forward_wo_act_sum(Tensor x,
                            Tensor quantized_hidden_states_buffer,
                            Tensor quantized_scale_buffer,
                            Tensor quantized_sum_buffer);

private:
    const bool act_sum;
    const bool use_per_token_quant;
    const float variance_epsilon;

    Tensor weight;
};
src/pytorch_compat.h
View file @
57e50f8d
...
...
@@ -4,103 +4,106 @@
#include "Tensor.h"

namespace pytorch_compat {

inline void TORCH_CHECK(bool cond, const std::string &msg = "") {
    assert(cond);
}

template <typename T>
inline void C10_CUDA_CHECK(T ret) {
    return checkCUDA(ret);
}

namespace at {

using ::Tensor;

constexpr auto kFloat32  = Tensor::FP32;
constexpr auto kFloat    = Tensor::FP32;
constexpr auto kFloat16  = Tensor::FP16;
constexpr auto kBFloat16 = Tensor::BF16;
constexpr auto kInt32    = Tensor::INT32;
constexpr auto kInt64    = Tensor::INT64;

struct Generator {
    Generator() {
        throw std::runtime_error("Not implemented");
    }
    std::mutex mutex_;
};

namespace cuda {

using ::getCurrentDeviceProperties;

struct StreamWrapper {
    cudaStream_t st;
    cudaStream_t stream() const {
        return st;
    }
};

inline StreamWrapper getCurrentCUDAStream() {
    return StreamWrapper(::getCurrentCUDAStream());
}

struct CUDAGuard {
    int dev;
};

namespace detail {
inline Generator getDefaultCUDAGenerator() {
    return Generator();
}
} // namespace detail

} // namespace cuda

using CUDAGeneratorImpl = Generator;

template <typename T>
std::unique_ptr<Generator> get_generator_or_default(std::optional<Generator> gen, T gen2) {
    throw std::runtime_error("Not implemented");
}

} // namespace at

namespace torch {

using at::kFloat32;
using at::kFloat;
using at::kFloat16;
using at::kBFloat16;
using at::kInt32;
using at::kInt64;

constexpr Device kCUDA = Device::cuda();

using IntArrayRef   = std::vector<int>;
using TensorOptions = Tensor::TensorOptions;

inline Tensor empty_like(const Tensor &tensor) {
    return Tensor::empty_like(tensor);
}

inline Tensor empty(TensorShape shape, Tensor::TensorOptions options) {
    return Tensor::empty(shape, options.dtype(), options.device());
}

inline Tensor zeros(TensorShape shape, Tensor::TensorOptions options) {
    return Tensor::empty(shape, options.dtype(), options.device()).zero_();
}

namespace nn {
namespace functional {

using PadFuncOptions = std::vector<int>;

inline Tensor pad(Tensor x, PadFuncOptions options) {
    throw std::runtime_error("Not implemented");
}

} // namespace functional
} // namespace nn

namespace indexing {

constexpr int None = 0;

struct Slice {
    int a;
    int b;
};

} // namespace indexing

} // namespace torch

namespace c10 {
using std::optional;
}
}
// namespace pytorch_compat
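A hedged sketch of how code ported from a PyTorch CUDA extension is expected to resolve against this shim. The surrounding function and the commented-out kernel launch are hypothetical; only the shim names come from the header above.

#include "pytorch_compat.h"

using namespace pytorch_compat;

void run_on_current_stream(Tensor x) {
    // at::cuda::getCurrentCUDAStream() wraps this project's own ::getCurrentCUDAStream().
    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();

    // torch::empty_like forwards to the native Tensor::empty_like.
    Tensor y = torch::empty_like(x);

    // my_kernel<<<grid, block, 0, stream>>>(...); // hypothetical launch site
    (void)stream;
    (void)y;
}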