fengzch-das / nunchaku · Commits

Commit 54e6d065, authored Feb 20, 2025 by muyangli

[major] support NVFP4; upgrade to 0.1

Parent: c7f41661

Changes: 45 files. Showing 5 changed files with 583 additions and 125 deletions (+583 -125)
src/kernels/zgemm/gemm_w4a4.cuh              +412  -28
src/kernels/zgemm/gemm_w4a4_launch.cuh         +7   -2
src/kernels/zgemm/gemm_w4a4_launch_impl.cuh  +157  -91
src/kernels/zgemm/gemm_w8a8.cu                 +2   -2
src/kernels/zgemm/zgemm.h                      +5   -2
src/kernels/zgemm/gemm_w4a4.cuh (view file @ 54e6d065)
#pragma once

#include "gemm_base.cuh"
// #include "gemm_w4a4_block.cuh"

namespace nunchaku::kernels {

...
@@ -19,6 +20,369 @@ class GEMM_W4A4<GEMMConfig_W4A4_FP16> : public GEMMBase<GEMMConfig_W4A4_FP16> {
public:
    IMPORT_GEMM_BASE(Config);
public:
    // micro-scales for FP4 MMA
    // each uint32_t is a 4*32 matrix of scales (for MMA of 64*32)
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1200
    static constexpr bool FP4_AVAILABLE = true;
#else
    static constexpr bool FP4_AVAILABLE = false;
#endif

    __device__ __forceinline__
    static void trap_no_fp4() {
        if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0) {
            printf("FP4 is not available on this device\n");
        }
        __syncthreads();
        __nanosleep(1000000);
        __trap();
    }

    static_assert(WARP_N % 32 == 0);
    static_assert(WARP_M % 32 == 0);

    static constexpr int WMSCALES_PACK_SIZE   = clamp(WARP_N / 32, 1, 4);
    static constexpr int WMSCALES_NUM_PACKS   = ceilDiv(WARP_N / 32, WMSCALES_PACK_SIZE);
    static constexpr int WMSCALES_VALID_LANES = WARP_SIZE;

    static constexpr int AMSCALES_PACK_SIZE   = clamp(WARP_M / 32, 1, 4);
    static constexpr int AMSCALES_NUM_PACKS   = ceilDiv(WARP_M / 32, AMSCALES_PACK_SIZE);
    static constexpr int AMSCALES_VALID_LANES = WARP_SIZE;

    struct packed_wmscale_t {
        uint32_t data[WMSCALES_PACK_SIZE];
    };
    struct packed_amscale_t {
        uint32_t data[AMSCALES_PACK_SIZE];
    };

    using amscale_warp = std::array<packed_amscale_t, AMSCALES_NUM_PACKS>;
    using wmscale_warp = std::array<packed_wmscale_t, WMSCALES_NUM_PACKS>;
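Note: for orientation, here is how these pack constants evaluate for an assumed tile size (WARP_M = WARP_N = 128 and WARP_SIZE = 32 are illustrative assumptions, not values taken from this file):

    // Sketch only, not part of the commit: evaluating the pack constants for an assumed
    // 128-wide warp tile. One pack then carries 4 scale words and covers the whole tile.
    constexpr int clamp_(int v, int lo, int hi) { return v < lo ? lo : (v > hi ? hi : v); }
    constexpr int ceilDiv_(int a, int b) { return (a + b - 1) / b; }

    static_assert(clamp_(128 / 32, 1, 4) == 4);   // *MSCALES_PACK_SIZE: 4 uint32_t words per pack
    static_assert(ceilDiv_(128 / 32, 4) == 1);    // *MSCALES_NUM_PACKS: one pack per warp tile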
    // amscales: [M / BLOCK_M, K / group size, NUM_WARPS, AMSCALES_NUM_PACKS, WARP_SIZE] of packed_amscale_t
    __device__ __forceinline__
    static void load_amscale(const packed_amscale_t *ptr, int group, amscale_warp &out, bool pred) {
        const int laneId = threadIdx.x % WARP_SIZE;
        const int warpId = threadIdx.x / WARP_SIZE;
#pragma unroll
        for (int i = 0; i < AMSCALES_NUM_PACKS; i++) {
            out[i] = load_pred(&ptr[(group * NUM_WARPS + warpId) * AMSCALES_NUM_PACKS * AMSCALES_VALID_LANES + i * AMSCALES_VALID_LANES + laneId], pred);
        }
    }

    // wmscales: [N / BLOCK_N, 1, K / group size, WMSCALES_NUM_PACKS, WMSCALES_VALID_LANES] of packed_wmscale_t
    __device__ __forceinline__
    static void load_wmscale(const packed_wmscale_t *ptr, int group, wmscale_warp &out, bool pred) {
        const int laneId = threadIdx.x % WARP_SIZE;
#pragma unroll
        for (int i = 0; i < WMSCALES_NUM_PACKS; i++) {
            out[i] = load_pred(&ptr[(group * WMSCALES_NUM_PACKS + i) * WMSCALES_VALID_LANES + laneId], pred);
        }
    }
    __device__ __forceinline__
    static void quantize_w4a4_fp4_from_fpsum_warp(const packed_fpsum_t (&fpsum)[INSN_K / INSN_N], packed_act_t &output, uint32_t &output_scale, int ida) {
        constexpr int NUM_GROUPS = 4;
        static_assert(NUM_GROUPS == INSN_K / INSN_N);
        constexpr float QVALUE_MAX = 6.0f;
        constexpr float RECPI_QVALUE_MAX = 1 / QVALUE_MAX;
        constexpr float MSCALE_MAX = 448.0f;

        const int laneId = threadIdx.x % WARP_SIZE;

        // 0 for row 0-7; 1 for row 8-15
        // each half2_t represents a 8*8 matrix
        half2_t input[2][INSN_K / INSN_N * 2];
#pragma unroll
        for (int i = 0; i < INSN_K / INSN_N; i++) {
            input[0][i * 2 + 0] = fpsum[i].data[0];
            input[0][i * 2 + 1] = fpsum[i].data[2];
            input[1][i * 2 + 0] = fpsum[i].data[1];
            input[1][i * 2 + 1] = fpsum[i].data[3];
        }

        auto maxabs = [](half2_t val) ALWAYSINLINE {
            val = __habs2(val);
            return __hmax(val.x, val.y);
        };

        // each half_t represents maxvalue in a 8*16 matrix
        half_t maxvalue[2][NUM_GROUPS];
#pragma unroll
        for (int i = 0; i < NUM_GROUPS; i++) {
            maxvalue[0][i] = __hmax(maxabs(input[0][i * 2]), maxabs(input[0][i * 2 + 1]));
            maxvalue[1][i] = __hmax(maxabs(input[1][i * 2]), maxabs(input[1][i * 2 + 1]));
        }
#pragma unroll
        for (int mask = 2; mask > 0; mask /= 2) {
#pragma unroll
            for (int i = 0; i < NUM_GROUPS; i++) {
                maxvalue[0][i] = __hmax(maxvalue[0][i], __shfl_xor_sync(~0, maxvalue[0][i], mask));
                maxvalue[1][i] = __hmax(maxvalue[1][i], __shfl_xor_sync(~0, maxvalue[1][i], mask));
            }
        }
        // lane 0,1,2,3 / 4,5,6,7 / ... should have identical maxvalue now

        float scale[2][NUM_GROUPS];
        float rscale[2][NUM_GROUPS];
#pragma unroll
        for (int i = 0; i < NUM_GROUPS; i++) {
            scale[0][i] = fminf(float(maxvalue[0][i]) * RECPI_QVALUE_MAX, MSCALE_MAX);
            scale[1][i] = fminf(float(maxvalue[1][i]) * RECPI_QVALUE_MAX, MSCALE_MAX);
            // TODO: check whether (1 / scale) or (1 / fp8scale) is better
            rscale[0][i] = cuda_frcp(scale[0][i]);
            rscale[1][i] = cuda_frcp(scale[1][i]);
        }

        uint32_t fp8scale[2];
        fp8scale[0] = quantize_float4_fp8(make_float4(scale[0][0], scale[0][1], scale[0][2], scale[0][3]));
        fp8scale[1] = quantize_float4_fp8(make_float4(scale[1][0], scale[1][1], scale[1][2], scale[1][3]));

        /**
         * output_scale pack format: (ida=0)
         * lane 0 => row 0 if ida==0
         * lane 1 => row 8 if ida==0
         * lane 2 => row 0 if ida==1
         * lane 3 => row 8 if ida==1
         * ...
         * lane i => quad (i/4) => row (i/4+8*(i%2)) if (i%4/2==ida) => srclane i, index i%2
         */
        if (laneId % 4 / 2 == ida) {
            output_scale = (laneId % 2 == 0) ? fp8scale[0] : fp8scale[1];
        }

        uint32_t qpacks[2][INSN_K / INSN_M * 2];
#pragma unroll
        for (int i = 0; i < INSN_K / INSN_M * 2; i++) {
#pragma unroll
            for (int j = 0; j < 2; j++) {
                float2 fval = half22float2(input[j][i]) * make_float2(rscale[j][i / 2], rscale[j][i / 2]);
                qpacks[j][i] = quantize_float2_fp4(fval) << (laneId % 4 * 8);
            }
        }

#pragma unroll
        for (int mask = 1; mask <= 2; mask *= 2) {
#pragma unroll
            for (int i = 0; i < INSN_K / INSN_M * 2; i++) {
#pragma unroll
                for (int j = 0; j < 2; j++) {
                    qpacks[j][i] |= __shfl_xor_sync(~0, qpacks[j][i], mask);
                }
            }
        }
        // lane 0,1,2,3 / 4,5,6,7 / ... should have identical qpacks now

#pragma unroll
        for (int i = 0; i < 4; i++) {
            if (laneId % 4 == i) {
                output.x = qpacks[0][0 + i];
                output.y = qpacks[1][0 + i];
                output.z = qpacks[0][4 + i];
                output.w = qpacks[1][4 + i];
            }
        }
    }
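Note: the warp-level routine above implements NVFP4 quantization: every 16 consecutive values along K share one FP8 (UE4M3) micro-scale, the values themselves are stored as FP4 E2M1 (largest magnitude 6), and the scale saturates at 448, the largest E4M3 value. A scalar reference sketch of the same per-group math follows (an illustration only; it ignores the kernel's register and lane layout):

    // Scalar reference sketch (assumption-laden, not the warp-level code above).
    #include <algorithm>
    #include <cmath>

    // Representable E2M1 magnitudes: 0, 0.5, 1, 1.5, 2, 3, 4, 6.
    static float round_to_e2m1(float x) {
        static const float grid[] = {0.f, 0.5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f};
        float best = grid[0];
        for (float g : grid)
            if (std::fabs(std::fabs(x) - g) < std::fabs(std::fabs(x) - best)) best = g;
        return std::copysign(best, x);
    }

    // Quantize one 16-element NVFP4 group: scale = maxabs / 6, clamped to 448 (stored as FP8
    // UE4M3 in the kernel); each element is then rounded to the nearest E2M1 value.
    void quantize_group_fp4(const float in[16], float out[16], float &scale) {
        float maxabs = 0.f;
        for (int i = 0; i < 16; i++) maxabs = std::max(maxabs, std::fabs(in[i]));
        scale = std::min(maxabs / 6.0f, 448.0f);
        float rscale = scale > 0.f ? 1.0f / scale : 0.f;
        for (int i = 0; i < 16; i++) out[i] = round_to_e2m1(in[i] * rscale);
    }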
    // m16n16k64 MMA
    // ida, idb in {0, 1}
    __device__ __forceinline__
    static packed_f32psum_t mma_fp4(packed_act_t act, packed_wgt_t wgt, packed_f32psum_t psum, uint32_t amscale, uint32_t wmscale, int ida, int idb) {
        packed_f32psum_t out;
        asm volatile(
            "mma.sync.aligned.m16n8k64.row.col.kind::mxf4nvf4.block_scale.scale_vec::4X.f32.e2m1.e2m1.f32.ue4m3 "
            "{%0, %1, %2, %3}, "
            "{%4, %5, %6, %7}, "
            "{%8, %9}, "
            "{%10, %11, %12, %13}, "
            "{%14}, {%15, %16}, "
            "{%17}, {%18, %19};"
            : "=f"(out.data[0]), "=f"(out.data[1]), "=f"(out.data[2]), "=f"(out.data[3])
            : "r"(act.x), "r"(act.y), "r"(act.z), "r"(act.w),
              "r"(wgt.x), "r"(wgt.y),
              "f"(psum.data[0]), "f"(psum.data[1]), "f"(psum.data[2]), "f"(psum.data[3]),
              "r"(amscale), "n"(0), "h"((short)ida),
              "r"(wmscale), "n"(0), "h"((short)(idb * 2)));
        asm volatile(
            "mma.sync.aligned.m16n8k64.row.col.kind::mxf4nvf4.block_scale.scale_vec::4X.f32.e2m1.e2m1.f32.ue4m3 "
            "{%0, %1, %2, %3}, "
            "{%4, %5, %6, %7}, "
            "{%8, %9}, "
            "{%10, %11, %12, %13}, "
            "{%14}, {%15, %16}, "
            "{%17}, {%18, %19};"
            : "=f"(out.data[4]), "=f"(out.data[5]), "=f"(out.data[6]), "=f"(out.data[7])
            : "r"(act.x), "r"(act.y), "r"(act.z), "r"(act.w),
              "r"(wgt.z), "r"(wgt.w),
              "f"(psum.data[4]), "f"(psum.data[5]), "f"(psum.data[6]), "f"(psum.data[7]),
              "r"(amscale), "n"(0), "h"((short)ida),
              "r"(wmscale), "n"(0), "h"((short)(idb * 2 + 1)));
        return out;
    }
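Note: functionally, each of the two PTX instructions above performs a block-scaled m16n8k64 MMA, D = (A with its micro-scales applied) x (B with its micro-scales applied) + C, where scale_vec::4X means one scale per 16-element group along K. A plain-float reference of that contraction (a sketch that ignores the tensor-core operand packing):

    // Functional sketch only: what one m16n8k64 block-scaled MMA accumulates.
    void mma_block_scaled_ref(const float A[16][64], const float B[8][64],
                              const float SFA[16][4], const float SFB[8][4],
                              float D[16][8]) {
        for (int m = 0; m < 16; m++)
            for (int n = 0; n < 8; n++) {
                float acc = D[m][n];                  // C operand, accumulated in place
                for (int k = 0; k < 64; k++) {
                    int g = k / 16;                   // scale_vec::4X => 4 groups of 16 along K
                    acc += (A[m][k] * SFA[m][g]) * (B[n][k] * SFB[n][g]);
                }
                D[m][n] = acc;
            }
    }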
    __device__ __forceinline__
    static void compute_fp4(act_warp A, wgt_warp W, amscale_warp amscale, wmscale_warp wmscale, f32psum_warp &psum) {
        const int laneId = threadIdx.x % WARP_SIZE;
        const int warpId = threadIdx.x / WARP_SIZE;
#pragma unroll
        for (int j = 0; j < WARP_N_TILES; j++) {
#pragma unroll
            for (int i = 0; i < WARP_M_TILES; i++) {
                psum[i * WARP_N_TILES + j] = mma_fp4(
                    A[i], W[j], psum[i * WARP_N_TILES + j],
                    amscale[i / 2 / AMSCALES_PACK_SIZE].data[i / 2 % AMSCALES_PACK_SIZE],
                    wmscale[j / 2 / WMSCALES_PACK_SIZE].data[j / 2 % WMSCALES_PACK_SIZE],
                    i % 2, j % 2);
            }
        }
    }
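Note on the scale indexing above: one uint32_t of micro-scales covers a 32-row (or 32-column) slab, i.e. two 16-wide MMA tiles, so tile i selects its pack, the word inside the pack, and the low/high half (the ida/idb argument) as follows. PACK_SIZE = 4 in the static_assert is an assumption for illustration.

    // Index arithmetic sketch for tile -> micro-scale word mapping, assuming PACK_SIZE == 4.
    constexpr int pack_index(int tile)  { return tile / 2 / 4; }  // which packed_*mscale_t
    constexpr int word_index(int tile)  { return tile / 2 % 4; }  // which uint32_t inside it
    constexpr int half_select(int tile) { return tile % 2; }      // ida / idb passed to mma_fp4
    static_assert(pack_index(5) == 0 && word_index(5) == 2 && half_select(5) == 1);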
    template<typename Epilogue, bool USE_ALPHA>
    __device__ __forceinline__
    static void gemm_w4a4_fp4_block(
        const BlockInfo binfo,
        const packed_act_t *act,
        const packed_wgt_t *wgt,
        const packed_amscale_t *ascales,
        const packed_wmscale_t *wscales,
        float alpha,    // per-tensor scale of weight
        int M, int N, int K,
        Epilogue::Arguments epilogueArgs,
        bool alwaysfalse)
    {
        constexpr int NUM_STAGES = 2;

        const int laneId = threadIdx.x % WARP_SIZE;
        const int warpId = threadIdx.x / WARP_SIZE;

        act_warp A[NUM_STAGES];            // 8 * 2
        wgt_warp W[NUM_STAGES];            // 32 * 2
        amscale_warp amscale[NUM_STAGES];  // 1 * 2
        wmscale_warp wmscale[NUM_STAGES];  // 4 * 2
        f32psum_warp fpsum;                // 128

        for (int k = 0; k < NUM_STAGES - 1; k++) {
            load_act(act, k, K, A[k], true);
            load_wgt(wgt, k, K, W[k], true);
            load_amscale(ascales, k, amscale[k], true);
            load_wmscale(wscales, k, wmscale[k], true);
        }

#pragma unroll
        for (auto &pack : fpsum) {
#pragma unroll
            for (int i = 0; i < 8; i++) {
                pack.data[i] = 0;
            }
        }

        int dummy = 0;

        for (int k1 = 0; k1 < K / WARP_K; k1 += NUM_STAGES) {
#pragma unroll
            for (int k2 = 0; k2 < NUM_STAGES; k2++) {
                int nextk = k1 + k2 + NUM_STAGES - 1;
                int idx = (k2 + NUM_STAGES - 1) % NUM_STAGES;
                bool pred = nextk < K / WARP_K;
                load_act(act, nextk, K, A[idx], pred);
                load_wgt(wgt, nextk, K, W[idx], pred);
                load_amscale(ascales, nextk, amscale[idx], pred);
                load_wmscale(wscales, nextk, wmscale[idx], pred);

                // __syncthreads();
                // if (alwaysfalse) {
                //     dummy = clock();
                // }

                compute_fp4(A[k2], W[k2], amscale[k2], wmscale[k2], fpsum);

                if (alwaysfalse) {
                    dummy = clock();
                }

                // asm volatile ("membar.cta;");
            }
        }

        unused_var(dummy, alwaysfalse);

        if constexpr (USE_ALPHA) {
#pragma unroll
            for (auto &pack : fpsum) {
#pragma unroll
                for (int i = 0; i < 8; i++) {
                    pack.data[i] *= alpha;
                }
            }
        }

        auto f16psum = packed_fp32_to_fp16(fpsum);

        CHECK_NAN(f16psum, "f16psum");

        Epilogue()(binfo, f16psum, M, N, K, epilogueArgs);
    }
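Note: the main loop above is a two-stage software pipeline: the loads for tile k+1 are issued before the MMAs for tile k, so global-memory latency overlaps with tensor-core work. Stripped of the GEMM specifics, the pattern is the following sketch (names are placeholders, not the kernel's helpers):

    // Simplified double-buffering sketch: compute the tile loaded one iteration ago while
    // the next tile is being fetched into the other register slot.
    template<typename LoadFn, typename ComputeFn>
    __device__ void pipelined_loop(int num_k_tiles, LoadFn load, ComputeFn compute) {
        constexpr int NUM_STAGES = 2;
        load(/*tile=*/0, /*slot=*/0, /*pred=*/true);            // prologue: prefetch first tile
        for (int k = 0; k < num_k_tiles; k++) {
            int next = k + NUM_STAGES - 1;
            load(next, next % NUM_STAGES, next < num_k_tiles);  // predicated off at the tail
            compute(k % NUM_STAGES);                            // consume the previously loaded tile
        }
    }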
    template<typename Epilogue, bool USE_ALPHA>
    struct gemm_w4a4_fp4_kernel {
        __device__
        void operator()(
            const packed_act_t *act,
            const packed_wgt_t *wgt,
            const packed_amscale_t *ascales,
            const packed_wmscale_t *wscales,
            float alpha,
            int M, int N, int K,
            Epilogue::Arguments epilogueArgs,
            bool swapBlockXY,
            bool alwaysfalse)
        {
            BlockInfo binfo = {
                .bm = (int)blockIdx.x,
                .bn = (int)blockIdx.y,
                .numBlocksM = (int)gridDim.x,
                .numBlocksN = (int)gridDim.y,
            };
            if (swapBlockXY) {
                std::swap(binfo.bm, binfo.bn);
                std::swap(binfo.numBlocksM, binfo.numBlocksN);
            }
            const int bm = binfo.bm;
            const int bn = binfo.bn;

            if constexpr (FP4_AVAILABLE) {
                gemm_w4a4_fp4_block<Epilogue, USE_ALPHA>(
                    binfo,
                    act + bm * (K / WARP_K) * NUM_WARPS * WARP_M_TILES * WARP_SIZE,
                    wgt + bn * (K / WARP_K) * WARP_N_TILES * WARP_SIZE,
                    ascales + bm * (K / WARP_K) * NUM_WARPS * AMSCALES_NUM_PACKS * AMSCALES_VALID_LANES,
                    wscales + bn * (K / WARP_K) * WMSCALES_NUM_PACKS * WMSCALES_VALID_LANES,
                    alpha,
                    M, N, K,
                    epilogueArgs,
                    alwaysfalse);
            } else {
                trap_no_fp4();
            }
        }
    };
public:
    template<bool ACT_UNSIGNED>
    __device__ __forceinline__
...
@@ -416,7 +780,7 @@ public:
    template<bool ACT_UNSIGNED, typename T>
    __device__ __forceinline__
    static void compute(act_warp A, wgt_warp W, ascale_warp ascale, wscale_warp wscale, T &fpsum) {
-       apply_scales([&](int i, int j) {
+       apply_scales<true>([&](int i, int j) {
            return mma<ACT_UNSIGNED>(A[i], W[j]);
        }, ascale, wscale, fpsum);
    }
...
@@ -530,6 +894,10 @@ public:
        const int laneId = threadIdx.x % WARP_SIZE;
        const int warpId = threadIdx.x / WARP_SIZE;

+#if 0
+       fpsum_warp fpsum;
+       GEMM_W4A4_Block<Config>()(act, wgt, ascales, wscales, K, fpsum, alwaysfalse);
+#else
        act_warp A[NUM_STAGES];          // 8
        wgt_warp W[NUM_STAGES];          // 32
        ascale_warp ascale[NUM_STAGES];  // 1
...
@@ -591,6 +959,8 @@ public:
        unused_var(dummy, alwaysfalse);
+#endif

#if 0
        auto f16psum = packed_fp32_to_fp16(fpsum);
#else
...
@@ -602,11 +972,13 @@ public:
        Epilogue()(binfo, f16psum, M, N, K, epilogueArgs);
    }

-   template<bool FUSE_GELU, bool USE_UNSIGNED>
+   template<bool FUSE_GELU, bool USE_UNSIGNED, bool USE_FP4>
    struct EpilogueQuantize {
+       using oscales_t = typename std::conditional_t<USE_FP4, packed_amscale_t, packed_ascale_t>;
        struct Arguments {
            packed_act_t *qout;
-           packed_ascale_t *oscales;
+           oscales_t *oscales;
            half_t shift_value;
            const packed_wscale_t *smooth_factor;
...
@@ -616,7 +988,7 @@ public:
        static constexpr int NUM_GROUPS = WARP_N_TILES / NUM_PACKS;

        __device__ __forceinline__
-       void apply_quantize(fpsum_warp fpsum, int M, int N, int K, packed_act_t *qout, packed_ascale_t *oscales, half_t shift_value, const packed_wscale_t *smooth_factor) {
+       void apply_quantize(fpsum_warp fpsum, int M, int N, int K, packed_act_t *qout, oscales_t *oscales, half_t shift_value, const packed_wscale_t *smooth_factor) {
            const int laneId = threadIdx.x % WARP_SIZE;
            const int warpId = threadIdx.x / WARP_SIZE;
...
@@ -627,6 +999,8 @@ public:
#pragma unroll
            for (int group = 0; group < NUM_GROUPS; group++) {
+               amscale_warp omscale;
+
#pragma unroll
                for (int i = 0; i < WARP_M_TILES; i++) {
                    packed_fpsum_t tmp[NUM_PACKS];
...
@@ -652,15 +1026,6 @@ public:
                        // dst = src;
                    }

-                   // auto h2div = [](half2_t a, half2_t b) ALWAYSINLINE {
-                   //     float2 af = half22float2(a);
-                   //     float2 bf = half22float2(b);
-                   //     float2 of;
-                   //     of.x = __fdividef(af.x, bf.x);
-                   //     of.y = __fdividef(af.y, bf.y);
-                   //     return float22half2<half2_t>(of);
-                   // };

                    tmp[j].data[0] = h2div(tmp[j].data[0], ws1);
                    tmp[j].data[1] = h2div(tmp[j].data[1], ws1);
                    tmp[j].data[2] = h2div(tmp[j].data[2], ws2);
...
@@ -668,13 +1033,26 @@ public:
                    }

                    packed_act_t qresult;
-                   quantize_w4a4_from_fpsum_warp<USE_UNSIGNED>(tmp, qresult, &oscale_shmem[warpId][i * INSN_M]);
+                   if constexpr (USE_FP4) {
+                       quantize_w4a4_fp4_from_fpsum_warp(tmp, qresult, omscale[i / 2 / AMSCALES_PACK_SIZE].data[i / 2 % AMSCALES_PACK_SIZE], i % 2);
+                   } else {
+                       quantize_w4a4_from_fpsum_warp<USE_UNSIGNED>(tmp, qresult, &oscale_shmem[warpId][i * INSN_M]);
+                   }
                    store(&qout[((group * NUM_WARPS + warpId) * WARP_M_TILES + i) * WARP_SIZE + laneId], qresult);
                }

-               __syncwarp();
-               pack_ascales(&oscale_shmem[warpId][0], &oscales[(group * NUM_WARPS + warpId) * ASCALES_NUM_PACKS * ASCALES_VALID_LANES]);
-               __syncwarp();
+               if constexpr (USE_FP4) {
+#pragma unroll
+                   for (int k = 0; k < AMSCALES_NUM_PACKS; k++) {
+                       store(&oscales[((group * NUM_WARPS + warpId) * AMSCALES_NUM_PACKS + k) * AMSCALES_VALID_LANES + laneId], omscale[k]);
+                   }
+               }
+               if constexpr (!USE_FP4) {
+                   __syncwarp();
+                   pack_ascales(&oscale_shmem[warpId][0], &oscales[(group * NUM_WARPS + warpId) * ASCALES_NUM_PACKS * ASCALES_VALID_LANES]);
+                   __syncwarp();
+               }
            }
        }
...
@@ -683,13 +1061,18 @@ public:
            const int bm = binfo.bm;
            const int bn = binfo.bn;

-           apply_quantize(
-               fpsum, M, N, K,
-               args.qout + (bm * N / WARP_K + bn * NUM_GROUPS) * NUM_WARPS * WARP_M_TILES * WARP_SIZE,
-               args.oscales + (bm * N / WARP_K + bn * NUM_GROUPS) * NUM_WARPS * ASCALES_NUM_PACKS * ASCALES_VALID_LANES,
-               args.shift_value,
-               args.smooth_factor + bn * WSCALES_NUM_PACKS * WSCALES_VALID_LANES
-           );
+           if constexpr (!USE_FP4 || FP4_AVAILABLE) {
+               apply_quantize(
+                   fpsum, M, N, K,
+                   args.qout + (bm * N / WARP_K + bn * NUM_GROUPS) * NUM_WARPS * WARP_M_TILES * WARP_SIZE,
+                   args.oscales + (bm * N / WARP_K + bn * NUM_GROUPS) * NUM_WARPS * (USE_FP4 ? AMSCALES_NUM_PACKS * AMSCALES_VALID_LANES : ASCALES_NUM_PACKS * ASCALES_VALID_LANES),
+                   args.shift_value,
+                   args.smooth_factor + bn * WSCALES_NUM_PACKS * WSCALES_VALID_LANES
+               );
+           } else {
+               trap_no_fp4();
+           }
        }
    };
    // using EpilogueQuantizeFuseGelu = EpilogueQuantize<true>;
...
@@ -937,8 +1320,10 @@ public:
        }
    };

-   template<bool fuse_glu>
+   template<bool fuse_glu, bool use_fp4>
    struct quantize_w4a4_fuse_lora_kernel {
+       using oscales_t = typename std::conditional_t<use_fp4, packed_amscale_t, packed_ascale_t>;
+
        static constexpr size_t SHMEM_PER_WARP = ceilDiv<size_t>(load_act_to_fpsum<fuse_glu>::SHMEM_SIZE, 128) * 128;
        static constexpr size_t SHMEM_SIZE = SHMEM_PER_WARP * NUM_WARPS;
...
@@ -946,7 +1331,7 @@ public:
            const half_t *input;
            const packed_wscale_t *smooth_factor;
            packed_act_t *output;
-           packed_ascale_t *oscales;
+           oscales_t *oscales;
            const packed_fpsum_t *lora_wgt_down;
            float *lora_act;
...
@@ -999,7 +1384,7 @@ public:
            .lora_act = args.lora_act,
        });

-       EpilogueQuantize<false, false>()(binfo, fpsum, args.M, args.N, 0, typename EpilogueQuantize<false, false>::Arguments{
+       EpilogueQuantize<false, false, use_fp4>()(binfo, fpsum, args.M, args.N, 0, typename EpilogueQuantize<false, false, use_fp4>::Arguments{
            .qout = args.output,
            .oscales = args.oscales,
            .shift_value = 0,
...
@@ -1488,7 +1873,6 @@ public:
            );
        }
    };
-};
};  // namespace nunchaku::kernels
\ No newline at end of file
src/kernels/zgemm/gemm_w4a4_launch.cuh (view file @ 54e6d065)
...
@@ -12,6 +12,8 @@ class GEMM_W4A4_Launch {
    using packed_wgt_t = typename GEMM::packed_wgt_t;
    using packed_ascale_t = typename GEMM::packed_ascale_t;
    using packed_wscale_t = typename GEMM::packed_wscale_t;
+   using packed_amscale_t = typename GEMM::packed_amscale_t;
+   using packed_wmscale_t = typename GEMM::packed_wmscale_t;
    using packed_fpsum_t = typename GEMM::packed_fpsum_t;
    using half_t = typename GEMM::half_t;
...
@@ -38,9 +40,12 @@ public:
        Tensor out_linearattn,  // linear [B, (M), N / 3]
        bool act_unsigned,
        std::vector<float> lora_scales,  // [R / 16]
-       bool fuse_silu
+       bool fuse_silu,
+       bool fp4,
+       float alpha,
+       Tensor wcscales  // packed ws [N]
    );
-   static void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu);
+   static void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu, bool fp4);
    static void quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales);
    static void quantize_w4a4_wgt(Tensor input, Tensor output, Tensor oscales);
...
src/kernels/zgemm/gemm_w4a4_launch_impl.cuh (view file @ 54e6d065)
...
@@ -30,7 +30,10 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>::gemm_w4a4(
    Tensor out_linearattn,  // linear [B, (M), N / 3]
    bool act_unsigned,
    std::vector<float> lora_scales,  // [R / 16]
-   bool fuse_silu
+   bool fuse_silu,
+   bool fp4,
+   float alpha,
+   Tensor wcscales  // packed ws [N]
) {
    int M = act.numel() / act.shape[-1];
    int N = wgt.shape[0];
...
@@ -68,58 +71,111 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>::gemm_w4a4(
        std::swap(grid.x, grid.y);
    }
    dispatchBool(fp4, [&]<bool USE_FP4>() {
        // test_sizeof<typename Epilogue::Arguments>();
        // std::apply([](auto ...args) {
        //     (test_sizeof<decltype(args)>(), ...);
        // }, args);
        // constexpr bool FP4_AVAILABLE = __CUDA_ARCH__ >= 1200;
        if constexpr (!USE_FP4) {
            dispatchBool(act_unsigned, [&]<bool ACT_UNSIGNED>() {
                auto func = invoke_kernel<typename GEMM::gemm_w4a4_kernel<Epilogue, ACT_UNSIGNED>,
                    const packed_act_t *,
                    const packed_wgt_t *,
                    const packed_ascale_t *,
                    const packed_wscale_t *,
                    int, int, int,
                    typename Epilogue::Arguments,
                    bool,
                    bool>;
                if (shmem >= 24 * 1024) {
                    checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
                }
                assert(alpha == 1.0f);
                func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, shmem>>>(
                    act.data_ptr<packed_act_t>(),
                    wgt.data_ptr<packed_wgt_t>(),
                    ascales.data_ptr<packed_ascale_t>(),
                    wscales.data_ptr<packed_wscale_t>(),
                    M, N, K, args, swapBlockMN, false);
                checkCUDA(cudaGetLastError());
            });
            return;
        }
        if constexpr (USE_FP4) {
            dispatchBool(alpha != 1.0f, [&]<bool USE_ALPHA>() {
                assert(!act_unsigned);
                auto func = invoke_kernel<typename GEMM::gemm_w4a4_fp4_kernel<Epilogue, USE_ALPHA>,
                    const packed_act_t *,
                    const packed_wgt_t *,
                    const packed_amscale_t *,
                    const packed_wmscale_t *,
                    float,
                    int, int, int,
                    typename Epilogue::Arguments,
                    bool,
                    bool>;
                if (shmem >= 24 * 1024) {
                    checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
                }
                assert(ascales.dtype() == Tensor::FP8_E4M3);
                assert(wscales.dtype() == Tensor::FP8_E4M3);
                // if constexpr (USE_FP4 && !FP4_AVAILABLE) {
                //     throw std::runtime_error("FP4 kernel is not available");
                // }
                func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, shmem>>>(
                    act.data_ptr<packed_act_t>(),
                    wgt.data_ptr<packed_wgt_t>(),
                    ascales.data_ptr<packed_amscale_t>(),
                    wscales.data_ptr<packed_wmscale_t>(),
                    alpha, M, N, K, args, swapBlockMN, false);
                checkCUDA(cudaGetLastError());
            });
            return;
        }
    });
};

    auto launch_bias = [&]<typename NextEpilogue>(NextEpilogue::Arguments nextArgs) {
        assert(!bias.valid() || bias.numel() == N);
        assert(!wcscales.valid() || wcscales.numel() == N);
        dispatchBool(bias.valid(), [&]<bool USE_BIAS>() {
            dispatchBool(wcscales.valid(), [&]<bool USE_SCALE>() {
                using EpilogueBias = typename GEMM::EpilogueBias<USE_BIAS, USE_SCALE>;
                // append EpilgoueNop to workaround mismatched memory layout of std::tuple between device and host code on Windows
                // ** sizeof(std::tuple<std::tuple<int>>) == 8 on device **
                using Epilogue = typename GEMM::EpilogueCombination<EpilogueBias, NextEpilogue, typename GEMM::EpilogueNop>;
                return launch.template operator()<Epilogue>({
                    typename EpilogueBias::Arguments{
                        .bias  = USE_BIAS  ? bias.data_ptr<packed_wscale_t>()     : nullptr,
                        .scale = USE_SCALE ? wcscales.data_ptr<packed_wscale_t>() : nullptr,
                    },
                    nextArgs,
                    {}
                });
            });
        });
    };
    // auto launch_bias = launch;
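Note: the launch path above turns the runtime flags (fp4, act_unsigned, alpha != 1, bias.valid(), wcscales.valid()) into template parameters through dispatchBool, so each combination instantiates its own kernel. A minimal stand-alone sketch of that pattern (an assumed implementation for illustration, not the project's actual helper; requires C++20 template lambdas):

    // Sketch of the dispatchBool pattern: lift a runtime bool into a compile-time parameter.
    #include <utility>

    template<typename F>
    void dispatchBool(bool value, F &&f) {
        if (value) std::forward<F>(f).template operator()<true>();
        else       std::forward<F>(f).template operator()<false>();
    }

    // Usage sketch: the runtime `fp4` flag selects which kernel template gets instantiated.
    // dispatchBool(fp4, [&]<bool USE_FP4>() {
    //     if constexpr (USE_FP4) { /* launch the gemm_w4a4_fp4_kernel path */ }
    //     else                   { /* launch the existing gemm_w4a4_kernel path */ }
    // });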
...
@@ -206,29 +262,32 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>::gemm_w4a4(
        static constexpr float SHIFT_GELU = 0.171875f;

-       constexpr bool USE_UNSIGNED = true;
-       using EpilogueQuantize = typename GEMM::EpilogueQuantize<false, USE_UNSIGNED>;
-       auto argsQuantize = typename EpilogueQuantize::Arguments{
-           .qout = qout.data_ptr<packed_act_t>(),
-           .oscales = oscales.data_ptr<packed_ascale_t>(),
-           .shift_value = SHIFT_GELU,
-           .smooth_factor = smooth_factor.data_ptr<packed_wscale_t>()
-       };
+       dispatchBool(fp4, [&]<bool USE_FP4>() {
+           constexpr bool USE_UNSIGNED = !USE_FP4;
+           using EpilogueQuantize = typename GEMM::EpilogueQuantize<false, USE_UNSIGNED, USE_FP4>;
+           auto argsQuantize = typename EpilogueQuantize::Arguments{
+               .qout = qout.data_ptr<packed_act_t>(),
+               .oscales = oscales.data_ptr<typename EpilogueQuantize::oscales_t>(),
+               .shift_value = USE_FP4 ? 0.0f : SHIFT_GELU,
+               .smooth_factor = smooth_factor.data_ptr<packed_wscale_t>()
+           };
            // TODO: check if gelu is needed
            if (out.valid()) {
                launch_lora.template operator()<typename GEMM::EpilogueCombination<typename GEMM::EpilogueDefault, EpilogueQuantize>, typename GEMM::EpilogueGelu>({
                    typename GEMM::EpilogueDefault::Arguments{
                        .out = out.data_ptr<half_t>(),
                        .actualM = actualM,
                        .actualN = actualN,
                    },
                    argsQuantize
                }, {});
            } else {
                launch_lora.template operator()<EpilogueQuantize, typename GEMM::EpilogueGelu>(argsQuantize, {});
            }
+       });
    } else if (out_linearattn.valid()) {
        assert(out_vk.valid());
...
@@ -326,7 +385,7 @@ void GEMM_W4A4_Launch<Config>::linearattn_vk_mul_q(Tensor q, Tensor vk) {
}

template<typename Config>
-void GEMM_W4A4_Launch<Config>::quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu) {
+void GEMM_W4A4_Launch<Config>::quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu, bool fp4) {
    const int actualM = input.numel() / input.shape[-1];
    const int actualN = input.shape[-1];
...
@@ -338,8 +397,13 @@ void GEMM_W4A4_Launch<Config>::quantize_w4a4_act_fuse_lora(Tensor input, Tensor
    assert(output.shape[-1] == N / 2);

    // assert(oscales.dtype() == Tensor::FP16);
-   assert(isTypeMatch<half_t>(oscales.dtype()));
-   assert(oscales.numel() == M * N / GEMM::WARP_K);
+   if (fp4) {
+       assert(oscales.dtype() == Tensor::FP8_E4M3);
+       assert(oscales.numel() == M * N / GEMM::WARP_K * 4);
+   } else {
+       assert(isTypeMatch<half_t>(oscales.dtype()));
+       assert(oscales.numel() == M * N / GEMM::WARP_K);
+   }

    const int rank = lora_down.shape[1];
...
@@ -354,30 +418,32 @@ void GEMM_W4A4_Launch<Config>::quantize_w4a4_act_fuse_lora(Tensor input, Tensor
    dispatchVal(rank, LoraRanks(), [&]<int RANK>() {
        dispatchBool(fuse_glu, [&]<bool FUSE_GLU>() {
            dispatchBool(fp4, [&]<bool USE_FP4>() {
                using Lora = typename GEMM::Lora<RANK>;
                using kernel = typename Lora::quantize_w4a4_fuse_lora_kernel<FUSE_GLU, USE_FP4>;
                auto func = invoke_kernel<kernel, typename kernel::Arguments>;
                checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));
                // log(std::format("quantize_w4a4_act_fuse_lora M={} N={} input={} output={} (size={} numel={})", M, N, input.data_ptr(), output.data_ptr(), output.buffer->getSize(), output.numel()));
                func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE>>>(typename kernel::Arguments{
                    .input = input.data_ptr<half_t>(),
                    .smooth_factor = smooth.valid() ? smooth.data_ptr<packed_wscale_t>() : nullptr,
                    .output = output.data_ptr<packed_act_t>(),
                    .oscales = oscales.data_ptr<typename kernel::oscales_t>(),
                    .lora_wgt_down = lora_down.data_ptr<packed_fpsum_t>(),
                    .lora_act = lora_act_out.data_ptr<float>(),
                    .M = M,
                    .N = N,
                    .actualM = actualM,
                    .actualN = actualN,
                });
                checkCUDA(cudaGetLastError());
            });
        });
    });
}
...
src/kernels/zgemm/gemm_w8a8.cu (view file @ 54e6d065)
...
@@ -100,9 +100,9 @@ void gemm_w8a8(Tensor act, // [M, K]
        // append EpilgoueNop to workaround mismatched memory layout of std::tuple between device and host code on Windows
        // ** sizeof(std::tuple<std::tuple<int>>) == 8 on device **
-       using Epilogue = GEMM::EpilogueCombination<GEMM::EpilogueBias, NextEpilogue, GEMM::EpilogueNop>;
+       using Epilogue = GEMM::EpilogueCombination<GEMM::EpilogueBias<true, false>, NextEpilogue, GEMM::EpilogueNop>;
        return launch.template operator()<Epilogue>({
-           GEMM::EpilogueBias::Arguments{
+           GEMM::EpilogueBias<true, false>::Arguments{
                .bias = bias.data_ptr<GEMM::packed_wscale_t>(),
            },
            nextArgs,
...
src/kernels/zgemm/zgemm.h (view file @ 54e6d065)
...
@@ -27,11 +27,14 @@ void gemm_w4a4(
    Tensor out_linearattn,  // linear [B, (M), N / 3]
    bool act_unsigned,
    std::vector<float> lora_scales,  // [R / 16]
-   bool fuse_silu
+   bool fuse_silu,
+   bool fp4,
+   float alpha,
+   Tensor wcscales
);
void linearattn_vk_mul_q(Tensor q, Tensor vk);
-void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth = {}, bool fuse_glu = false);
+void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth = {}, bool fuse_glu = false, bool fp4 = false);
void quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales);
void quantize_w4a4_wgt(Tensor input, Tensor output, Tensor oscales);
...
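Note: a hypothetical call into the new FP4 path might look like the following (the namespace and tensor setup are assumptions for illustration; for FP4, the launch implementation asserts that oscales is FP8_E4M3 and holds four times as many elements as in the FP16-scale path):

    // Hypothetical call-site sketch only: tensors are assumed to be already allocated and
    // packed in the layouts the kernels expect.
    void quantize_for_fp4_path(Tensor input, Tensor output, Tensor oscales,
                               Tensor lora_down, Tensor lora_act_out, Tensor smooth) {
        // Namespace assumed to match the kernel headers (nunchaku::kernels).
        nunchaku::kernels::quantize_w4a4_act_fuse_lora(
            input, output, oscales, lora_down, lora_act_out, smooth,
            /*fuse_glu=*/false,
            /*fp4=*/true);   // new in this commit: emit NVFP4 activations plus micro-scales
    }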