fengzch-das / nunchaku · Commits

Commit 54e6d065
Authored Feb 20, 2025 by muyangli
Parent: c7f41661

    [major] support NVFP4; upgrade to 0.1

Changes: 45
Showing 5 changed files with 583 additions and 125 deletions (+583 / -125)
src/kernels/zgemm/gemm_w4a4.cuh              +412  -28
src/kernels/zgemm/gemm_w4a4_launch.cuh         +7   -2
src/kernels/zgemm/gemm_w4a4_launch_impl.cuh  +157  -91
src/kernels/zgemm/gemm_w8a8.cu                 +2   -2
src/kernels/zgemm/zgemm.h                      +5   -2
src/kernels/zgemm/gemm_w4a4.cuh
This diff is collapsed.
src/kernels/zgemm/gemm_w4a4_launch.cuh
...
...
@@ -12,6 +12,8 @@ class GEMM_W4A4_Launch {
     using packed_wgt_t    = typename GEMM::packed_wgt_t;
     using packed_ascale_t = typename GEMM::packed_ascale_t;
     using packed_wscale_t = typename GEMM::packed_wscale_t;
+    using packed_amscale_t = typename GEMM::packed_amscale_t;
+    using packed_wmscale_t = typename GEMM::packed_wmscale_t;
     using packed_fpsum_t  = typename GEMM::packed_fpsum_t;
     using half_t          = typename GEMM::half_t;
...
...
@@ -38,9 +40,12 @@ public:
         Tensor out_linearattn,          // linear [B, (M), N / 3]
         bool act_unsigned,
         std::vector<float> lora_scales, // [R / 16]
-        bool fuse_silu
+        bool fuse_silu,
+        bool fp4,
+        float alpha,
+        Tensor wcscales                 // packed ws [N]
     );
-    static void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu);
+    static void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu, bool fp4);
     static void quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales);
     static void quantize_w4a4_wgt(Tensor input, Tensor output, Tensor oscales);
...
...
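Context for the new packed_amscale_t / packed_wmscale_t aliases and the extra fp4/alpha/wcscales parameters: the collapsed gemm_w4a4.cuh diff is not shown on this page, but the FP8_E4M3 scale assertions later in this commit follow the usual NVFP4 recipe, in which 4-bit E2M1 values carry one FP8 (E4M3) scale per small group of elements. The sketch below is a generic host-side illustration of that scheme, not the repository's packing code; the group size of 16 and the helper names are assumptions made only for this example.

// Generic illustration of NVFP4-style group quantization (not nunchaku's exact layout):
// each group of 16 values becomes 4-bit E2M1 codes plus one per-group scale
// (stored as FP8 E4M3 on real hardware; kept as float here for clarity).
#include <array>
#include <cmath>
#include <cstdint>
#include <cstdio>

static const float kFp4Values[8] = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f}; // E2M1 magnitudes

static uint8_t encode_fp4(float x) {
    uint8_t sign = x < 0.0f ? 0x8 : 0x0;
    float mag = std::fabs(x);
    int best = 0;
    for (int i = 1; i < 8; ++i)
        if (std::fabs(kFp4Values[i] - mag) < std::fabs(kFp4Values[best] - mag))
            best = i;
    return sign | static_cast<uint8_t>(best);
}

// Quantize one group of 16 floats: fills 16 4-bit codes (one per byte for clarity), returns the scale.
static float quantize_group(const std::array<float, 16> &in, std::array<uint8_t, 16> &codes) {
    float amax = 0.0f;
    for (float v : in) amax = std::fmax(amax, std::fabs(v));
    float scale = amax > 0.0f ? amax / 6.0f : 1.0f;  // map the group max onto the largest E2M1 value
    for (int i = 0; i < 16; ++i) codes[i] = encode_fp4(in[i] / scale);
    return scale;  // a real implementation would round this to FP8 E4M3
}

int main() {
    std::array<float, 16> x{};
    for (int i = 0; i < 16; ++i) x[i] = 0.1f * static_cast<float>(i - 8);
    std::array<uint8_t, 16> codes;
    float scale = quantize_group(x, codes);
    std::printf("scale = %f, code[3] = %u\n", scale, static_cast<unsigned>(codes[3]));
    return 0;
}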
src/kernels/zgemm/gemm_w4a4_launch_impl.cuh
...
...
@@ -30,7 +30,10 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>::gemm_w4a4(
     Tensor out_linearattn,          // linear [B, (M), N / 3]
     bool act_unsigned,
     std::vector<float> lora_scales, // [R / 16]
-    bool fuse_silu
+    bool fuse_silu,
+    bool fp4,
+    float alpha,
+    Tensor wcscales                 // packed ws [N]
 )
 {
     int M = act.numel() / act.shape[-1];
     int N = wgt.shape[0];
...
...
@@ -68,58 +71,111 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>::gemm_w4a4(
             std::swap(grid.x, grid.y);
         }
 
-        dispatchBool(act_unsigned, [&]<bool ACT_UNSIGNED>() {
-            // test_sizeof<typename Epilogue::Arguments>();
-            // std::apply([](auto ...args) {
-            //     (test_sizeof<decltype(args)>(), ...);
-            // }, args);
-            using kernel = typename GEMM::gemm_w4a4_kernel<Epilogue, ACT_UNSIGNED>;
-            auto func = invoke_kernel<kernel, const packed_act_t *, const packed_wgt_t *, const packed_ascale_t *, const packed_wscale_t *, int, int, int, typename Epilogue::Arguments, bool, bool>;
-            if (shmem >= 24 * 1024) {
-                checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
-            }
-            func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, shmem>>>(act.data_ptr<packed_act_t>(), wgt.data_ptr<packed_wgt_t>(), ascales.data_ptr<packed_ascale_t>(), wscales.data_ptr<packed_wscale_t>(), M, N, K, args, swapBlockMN, false);
-            checkCUDA(cudaGetLastError());
+        dispatchBool(fp4, [&]<bool USE_FP4>() {
+            // constexpr bool FP4_AVAILABLE = __CUDA_ARCH__ >= 1200;
+            if constexpr (!USE_FP4) {
+                dispatchBool(act_unsigned, [&]<bool ACT_UNSIGNED>() {
+                    auto func = invoke_kernel<typename GEMM::gemm_w4a4_kernel<Epilogue, ACT_UNSIGNED>, const packed_act_t *, const packed_wgt_t *, const packed_ascale_t *, const packed_wscale_t *, int, int, int, typename Epilogue::Arguments, bool, bool>;
+                    if (shmem >= 24 * 1024) {
+                        checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
+                    }
+                    assert(alpha == 1.0f);
+                    func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, shmem>>>(act.data_ptr<packed_act_t>(), wgt.data_ptr<packed_wgt_t>(), ascales.data_ptr<packed_ascale_t>(), wscales.data_ptr<packed_wscale_t>(), M, N, K, args, swapBlockMN, false);
+                    checkCUDA(cudaGetLastError());
+                });
+                return;
+            }
+            if constexpr (USE_FP4) {
+                dispatchBool(alpha != 1.0f, [&]<bool USE_ALPHA>() {
+                    assert(!act_unsigned);
+                    auto func = invoke_kernel<typename GEMM::gemm_w4a4_fp4_kernel<Epilogue, USE_ALPHA>, const packed_act_t *, const packed_wgt_t *, const packed_amscale_t *, const packed_wmscale_t *, float, int, int, int, typename Epilogue::Arguments, bool, bool>;
+                    if (shmem >= 24 * 1024) {
+                        checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
+                    }
+                    assert(ascales.dtype() == Tensor::FP8_E4M3);
+                    assert(wscales.dtype() == Tensor::FP8_E4M3);
+                    func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, shmem>>>(act.data_ptr<packed_act_t>(), wgt.data_ptr<packed_wgt_t>(), ascales.data_ptr<packed_amscale_t>(), wscales.data_ptr<packed_wmscale_t>(), alpha, M, N, K, args, swapBlockMN, false);
+                    checkCUDA(cudaGetLastError());
+                });
+                return;
+            }
+            // if constexpr (USE_FP4 && !FP4_AVAILABLE) {
+            //     throw std::runtime_error("FP4 kernel is not available");
+            // }
         });
     };
 
     auto launch_bias = [&]<typename NextEpilogue>(NextEpilogue::Arguments nextArgs) {
-        if (!bias.valid()) {
-            return launch.template operator()<NextEpilogue>(nextArgs);
-        }
-        assert(bias.numel() == N);
-        // append EpilgoueNop to workaround mismatched memory layout of std::tuple between device and host code on Windows
-        // ** sizeof(std::tuple<std::tuple<int>>) == 8 on device **
-        using Epilogue = typename GEMM::EpilogueCombination<typename GEMM::EpilogueBias, NextEpilogue, typename GEMM::EpilogueNop>;
-        return launch.template operator()<Epilogue>({
-            typename GEMM::EpilogueBias::Arguments{
-                .bias = bias.data_ptr<packed_wscale_t>(),
-            },
-            nextArgs,
-            {}
+        assert(!bias.valid() || bias.numel() == N);
+        assert(!wcscales.valid() || wcscales.numel() == N);
+        dispatchBool(bias.valid(), [&]<bool USE_BIAS>() {
+            dispatchBool(wcscales.valid(), [&]<bool USE_SCALE>() {
+                using EpilogueBias = typename GEMM::EpilogueBias<USE_BIAS, USE_SCALE>;
+                // append EpilgoueNop to workaround mismatched memory layout of std::tuple between device and host code on Windows
+                // ** sizeof(std::tuple<std::tuple<int>>) == 8 on device **
+                using Epilogue = typename GEMM::EpilogueCombination<EpilogueBias, NextEpilogue, typename GEMM::EpilogueNop>;
+                return launch.template operator()<Epilogue>({
+                    typename EpilogueBias::Arguments{
+                        .bias  = USE_BIAS  ? bias.data_ptr<packed_wscale_t>()     : nullptr,
+                        .scale = USE_SCALE ? wcscales.data_ptr<packed_wscale_t>() : nullptr,
+                    },
+                    nextArgs,
+                    {}
+                });
+            });
         });
     };
     // auto launch_bias = launch;
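The hunk above leans on dispatchBool/dispatchVal to lift runtime flags (fp4, act_unsigned, alpha != 1.0f, bias.valid(), wcscales.valid()) into compile-time template parameters, so each flag combination instantiates its own kernel and epilogue. The following is a minimal sketch of that pattern using a C++20 templated lambda; it is not nunchaku's actual helper, and the names are illustrative only.

// Minimal sketch of the dispatchBool pattern: a runtime bool selects which
// compile-time instantiation of a templated lambda gets called.
#include <cstdio>

template <typename F>
void dispatchBoolSketch(bool value, F &&f) {
    if (value) {
        f.template operator()<true>();   // instantiate the <true> version
    } else {
        f.template operator()<false>();  // instantiate the <false> version
    }
}

int main() {
    bool fp4 = true;  // runtime flag, e.g. read from a model config
    dispatchBoolSketch(fp4, [&]<bool USE_FP4>() {
        if constexpr (USE_FP4) {
            std::printf("would launch the FP4 kernel\n");
        } else {
            std::printf("would launch the INT4 kernel\n");
        }
    });
    return 0;
}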
...
...
@@ -206,29 +262,32 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>::gemm_w4a4(
         static constexpr float SHIFT_GELU = 0.171875f;
 
-        constexpr bool USE_UNSIGNED = true;
-        using EpilogueQuantize = typename GEMM::EpilogueQuantize<false, USE_UNSIGNED>;
-        auto argsQuantize = typename EpilogueQuantize::Arguments{
-            .qout = qout.data_ptr<packed_act_t>(),
-            .oscales = oscales.data_ptr<packed_ascale_t>(),
-            .shift_value = SHIFT_GELU,
-            .smooth_factor = smooth_factor.data_ptr<packed_wscale_t>()
-        };
-        // TODO: check if gelu is needed
-        if (out.valid()) {
-            launch_lora.template operator()<typename GEMM::EpilogueCombination<typename GEMM::EpilogueDefault, EpilogueQuantize>, typename GEMM::EpilogueGelu>({
-                typename GEMM::EpilogueDefault::Arguments{
-                    .out = out.data_ptr<half_t>(),
-                    .actualM = actualM,
-                    .actualN = actualN,
-                },
-                argsQuantize
-            }, {});
-        } else {
-            launch_lora.template operator()<EpilogueQuantize, typename GEMM::EpilogueGelu>(argsQuantize, {});
-        }
+        dispatchBool(fp4, [&]<bool USE_FP4>() {
+            constexpr bool USE_UNSIGNED = !USE_FP4;
+            using EpilogueQuantize = typename GEMM::EpilogueQuantize<false, USE_UNSIGNED, USE_FP4>;
+            auto argsQuantize = typename EpilogueQuantize::Arguments{
+                .qout = qout.data_ptr<packed_act_t>(),
+                .oscales = oscales.data_ptr<typename EpilogueQuantize::oscales_t>(),
+                .shift_value = USE_FP4 ? 0.0f : SHIFT_GELU,
+                .smooth_factor = smooth_factor.data_ptr<packed_wscale_t>()
+            };
+            // TODO: check if gelu is needed
+            if (out.valid()) {
+                launch_lora.template operator()<typename GEMM::EpilogueCombination<typename GEMM::EpilogueDefault, EpilogueQuantize>, typename GEMM::EpilogueGelu>({
+                    typename GEMM::EpilogueDefault::Arguments{
+                        .out = out.data_ptr<half_t>(),
+                        .actualM = actualM,
+                        .actualN = actualN,
+                    },
+                    argsQuantize
+                }, {});
+            } else {
+                launch_lora.template operator()<EpilogueQuantize, typename GEMM::EpilogueGelu>(argsQuantize, {});
+            }
+        });
     } else if (out_linearattn.valid()) {
         assert(out_vk.valid());
...
...
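A note on shift_value in the hunk above (an inference from the code, not something the commit states): the non-FP4 path quantizes GELU outputs to unsigned INT4 and shifts them by SHIFT_GELU = 0.171875 so the slightly negative tail of GELU becomes non-negative, while the FP4 path keeps signed values and sets the shift to 0. The sketch below just checks the arithmetic behind that reading.

// GELU's minimum is about -0.17 (near x = -0.75), so adding 0.171875 keeps the
// shifted activations non-negative for an unsigned quantizer.
#include <cmath>
#include <cstdio>

static float gelu(float x) {             // exact GELU: x * Phi(x)
    return 0.5f * x * (1.0f + std::erff(x / std::sqrt(2.0f)));
}

int main() {
    const float SHIFT_GELU = 0.171875f;
    float min_val = 0.0f;
    for (float x = -6.0f; x <= 6.0f; x += 1e-3f)
        min_val = std::fmin(min_val, gelu(x));
    std::printf("min GELU value  : %f\n", min_val);               // about -0.17
    std::printf("after the shift : %f\n", min_val + SHIFT_GELU);  // non-negative
    return 0;
}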
@@ -326,7 +385,7 @@ void GEMM_W4A4_Launch<Config>::linearattn_vk_mul_q(Tensor q, Tensor vk) {
 }
 
 template<typename Config>
-void GEMM_W4A4_Launch<Config>::quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu) {
+void GEMM_W4A4_Launch<Config>::quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu, bool fp4) {
     const int actualM = input.numel() / input.shape[-1];
     const int actualN = input.shape[-1];
...
...
@@ -338,8 +397,13 @@ void GEMM_W4A4_Launch<Config>::quantize_w4a4_act_fuse_lora(Tensor input, Tensor
     assert(output.shape[-1] == N / 2);
 
     // assert(oscales.dtype() == Tensor::FP16);
-    assert(isTypeMatch<half_t>(oscales.dtype()));
-    assert(oscales.numel() == M * N / GEMM::WARP_K);
+    if (fp4) {
+        assert(oscales.dtype() == Tensor::FP8_E4M3);
+        assert(oscales.numel() == M * N / GEMM::WARP_K * 4);
+    } else {
+        assert(isTypeMatch<half_t>(oscales.dtype()));
+        assert(oscales.numel() == M * N / GEMM::WARP_K);
+    }
 
     const int rank = lora_down.shape[1];
...
...
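The asserts above change the scale bookkeeping: the INT4 path keeps one half_t scale per WARP_K activations, while the FP4 path keeps four times as many FP8 E4M3 scales, i.e. one scale per WARP_K / 4 values. WARP_K itself is defined in the collapsed gemm_w4a4.cuh diff; the value 64 in the sketch below is an assumption used only to show the arithmetic (with it, the FP4 group size works out to 16 values per scale, matching NVFP4's group size).

// Scale-count arithmetic implied by the asserts; WARP_K = 64 is assumed here.
#include <cstdio>

int main() {
    const long long M = 4096, N = 4096, WARP_K = 64;   // WARP_K assumed
    long long int4_scales = M * N / WARP_K;             // one half_t scale per WARP_K values
    long long fp4_scales  = M * N / WARP_K * 4;         // one FP8 E4M3 scale per WARP_K / 4 values
    std::printf("values per scale: int4 = %lld, fp4 = %lld\n", WARP_K, WARP_K / 4);
    std::printf("scale elements  : int4 = %lld, fp4 = %lld\n", int4_scales, fp4_scales);
    return 0;
}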
@@ -354,30 +418,32 @@
     dispatchVal(rank, LoraRanks(), [&]<int RANK>() {
         dispatchBool(fuse_glu, [&]<bool FUSE_GLU>() {
-            using Lora = typename GEMM::Lora<RANK>;
-            using kernel = typename Lora::quantize_w4a4_fuse_lora_kernel<FUSE_GLU>;
-            auto func = invoke_kernel<kernel, typename kernel::Arguments>;
-            checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));
-            // log(std::format("quantize_w4a4_act_fuse_lora M={} N={} input={} output={} (size={} numel={})", M, N, input.data_ptr(), output.data_ptr(), output.buffer->getSize(), output.numel()));
-            func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE>>>(typename kernel::Arguments{
-                .input = input.data_ptr<half_t>(),
-                .smooth_factor = smooth.valid() ? smooth.data_ptr<packed_wscale_t>() : nullptr,
-                .output = output.data_ptr<packed_act_t>(),
-                .oscales = oscales.data_ptr<packed_ascale_t>(),
-                .lora_wgt_down = lora_down.data_ptr<packed_fpsum_t>(),
-                .lora_act = lora_act_out.data_ptr<float>(),
-                .M = M,
-                .N = N,
-                .actualM = actualM,
-                .actualN = actualN,
-            });
-            checkCUDA(cudaGetLastError());
+            dispatchBool(fp4, [&]<bool USE_FP4>() {
+                using Lora = typename GEMM::Lora<RANK>;
+                using kernel = typename Lora::quantize_w4a4_fuse_lora_kernel<FUSE_GLU, USE_FP4>;
+                auto func = invoke_kernel<kernel, typename kernel::Arguments>;
+                checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));
+                // log(std::format("quantize_w4a4_act_fuse_lora M={} N={} input={} output={} (size={} numel={})", M, N, input.data_ptr(), output.data_ptr(), output.buffer->getSize(), output.numel()));
+                func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE>>>(typename kernel::Arguments{
+                    .input = input.data_ptr<half_t>(),
+                    .smooth_factor = smooth.valid() ? smooth.data_ptr<packed_wscale_t>() : nullptr,
+                    .output = output.data_ptr<packed_act_t>(),
+                    .oscales = oscales.data_ptr<typename kernel::oscales_t>(),
+                    .lora_wgt_down = lora_down.data_ptr<packed_fpsum_t>(),
+                    .lora_act = lora_act_out.data_ptr<float>(),
+                    .M = M,
+                    .N = N,
+                    .actualM = actualM,
+                    .actualN = actualN,
+                });
+                checkCUDA(cudaGetLastError());
+            });
         });
     });
 }
...
...
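Both this file's launch sites use the same pattern: build a function pointer with invoke_kernel, raise the kernel's dynamic shared-memory limit via cudaFuncSetAttribute(cudaFuncAttributeMaxDynamicSharedMemorySize, ...), then launch with the byte count as the third <<<>>> parameter. The sketch below shows that CUDA mechanism in isolation; the kernel body, grid sizes, and the 96 KB figure are illustrative, not taken from this commit.

// Generic CUDA sketch: opt a kernel into more dynamic shared memory than the
// 48 KB default before launching it.
#include <cuda_runtime.h>
#include <cstdio>

__global__ void demoKernel(float *out) {
    extern __shared__ float smem[];   // dynamic shared memory
    smem[threadIdx.x] = threadIdx.x;
    __syncthreads();
    if (threadIdx.x == 0) out[blockIdx.x] = smem[0];
}

int main() {
    const size_t shmem = 96 * 1024;   // above the 48 KB default, so the opt-in is required
    cudaFuncSetAttribute(demoKernel, cudaFuncAttributeMaxDynamicSharedMemorySize, (int)shmem);

    float *out = nullptr;
    cudaMalloc(&out, 4 * sizeof(float));
    demoKernel<<<4, 128, shmem>>>(out);
    printf("launch status: %s\n", cudaGetErrorString(cudaGetLastError()));
    cudaFree(out);
    return 0;
}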
src/kernels/zgemm/gemm_w8a8.cu
...
...
@@ -100,9 +100,9 @@ void gemm_w8a8(Tensor act, // [M, K]
         // append EpilgoueNop to workaround mismatched memory layout of std::tuple between device and host code on Windows
         // ** sizeof(std::tuple<std::tuple<int>>) == 8 on device **
-        using Epilogue = GEMM::EpilogueCombination<GEMM::EpilogueBias, NextEpilogue, GEMM::EpilogueNop>;
+        using Epilogue = GEMM::EpilogueCombination<GEMM::EpilogueBias<true, false>, NextEpilogue, GEMM::EpilogueNop>;
         return launch.template operator()<Epilogue>({
-            GEMM::EpilogueBias::Arguments{
+            GEMM::EpilogueBias<true, false>::Arguments{
                 .bias = bias.data_ptr<GEMM::packed_wscale_t>(),
             },
             nextArgs,
...
...
src/kernels/zgemm/zgemm.h
...
...
@@ -27,11 +27,14 @@ void gemm_w4a4(
     Tensor out_linearattn,          // linear [B, (M), N / 3]
     bool act_unsigned,
     std::vector<float> lora_scales, // [R / 16]
-    bool fuse_silu
+    bool fuse_silu,
+    bool fp4,
+    float alpha,
+    Tensor wcscales
 );
 void linearattn_vk_mul_q(Tensor q, Tensor vk);
-void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth = {}, bool fuse_glu = false);
+void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth = {}, bool fuse_glu = false, bool fp4 = false);
 void quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales);
 void quantize_w4a4_wgt(Tensor input, Tensor output, Tensor oscales);
...
...
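Caller-side view of the updated public declarations in zgemm.h: existing call sites keep compiling because smooth, fuse_glu, and fp4 are defaulted, while NVFP4 callers pass fp4 = true and, per the asserts in gemm_w4a4_launch_impl.cuh, must supply FP8 E4M3 scale tensors with four times the element count. The sketch below is schematic; the Tensor construction and include path are assumptions since they are not shown on this page.

// Caller-side sketch of the new defaulted parameters (Tensor setup is schematic).
#include "zgemm.h"

void run_quantize(Tensor input, Tensor output, Tensor oscales,
                  Tensor lora_down, Tensor lora_act_out) {
    // Pre-0.1 call sites keep working: smooth, fuse_glu, and fp4 use their defaults.
    quantize_w4a4_act_fuse_lora(input, output, oscales, lora_down, lora_act_out);

    // NVFP4 path: oscales must be FP8 E4M3 with 4x the element count
    // (see the asserts in gemm_w4a4_launch_impl.cuh).
    quantize_w4a4_act_fuse_lora(input, output, oscales, lora_down, lora_act_out,
                                /*smooth=*/{}, /*fuse_glu=*/false, /*fp4=*/true);
}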