fengzch-das / nunchaku · Commits · 92ac7b40

Commit 92ac7b40
Add our own FP16 Attention implementation

Authored Mar 22, 2025 by sxtyzhangzk; committed by Zhekai Zhang on Apr 01, 2025
Parent: 182c323c
Showing 5 of 25 changed files, with 315 additions and 153 deletions (+315 / -153).
src/kernels/zgemm/gemm_w4a4_launch_fp16_fp4.cu    +1   -1
src/kernels/zgemm/gemm_w4a4_launch_fp16_int4.cu   +5   -0
src/kernels/zgemm/gemm_w4a4_launch_impl.cuh       +201 -151
src/kernels/zgemm/gemm_w4a4_test.cu               +90  -0
src/kernels/zgemm/zgemm.h                         +18  -1
src/kernels/zgemm/gemm_w4a4_launch_fp16.cu → src/kernels/zgemm/gemm_w4a4_launch_fp16_fp4.cu

#include "gemm_w4a4_launch_impl.cuh"

namespace nunchaku::kernels {

-template class GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>;
+template class GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, true>;

};
\ No newline at end of file
src/kernels/zgemm/gemm_w4a4_launch_fp16_int4.cu (new file, mode 100644)

#include "gemm_w4a4_launch_impl.cuh"

namespace nunchaku::kernels {

template class GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>;

};
\ No newline at end of file
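For context, the rename plus the new file above split the explicit instantiations of GEMM_W4A4_Launch by its new second template parameter: gemm_w4a4_launch_fp16_fp4.cu instantiates the FP4 (true) variant and gemm_w4a4_launch_fp16_int4.cu the INT4 (false) variant, so each compiles in its own translation unit. A minimal, self-contained sketch of that pattern, with illustrative names (Launcher, MyConfig) rather than the repository's classes:

    // Minimal single-file sketch: a class template keyed on a compile-time bool,
    // with one explicit instantiation per variant.
    #include <cstdio>

    struct MyConfig { static constexpr int BLOCK = 64; };

    template<typename Config, bool USE_FP4>
    struct Launcher {
        void run() { std::printf("BLOCK=%d USE_FP4=%d\n", Config::BLOCK, int(USE_FP4)); }
    };

    // In the repository, each of the next two lines lives in its own translation unit
    // (gemm_w4a4_launch_fp16_fp4.cu and gemm_w4a4_launch_fp16_int4.cu respectively),
    // so the FP4 and INT4 kernels are compiled and linked separately.
    template struct Launcher<MyConfig, true>;
    template struct Launcher<MyConfig, false>;

    int main() {
        Launcher<MyConfig, true>{}.run();
        Launcher<MyConfig, false>{}.run();
    }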
src/kernels/zgemm/gemm_w4a4_launch_impl.cuh
@@ -3,11 +3,11 @@

namespace nunchaku::kernels {

#ifndef __INTELLISENSE__
-template<typename Config>
-void GEMM_W4A4_Launch<Config>::gemm_w4a4(
+template<typename Config, bool USE_FP4>
+void GEMM_W4A4_Launch<Config, USE_FP4>::gemm_w4a4(
#else
template<>
-void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>::gemm_w4a4(
+void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
#endif
    Tensor act,  // packed act [M, K / 2]
    Tensor wgt,  // packed act [N, K / 2]
@@ -33,8 +33,17 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>::gemm_w4a4(
    bool fuse_silu, bool fp4,
    float alpha,
-   Tensor wcscales     // packed ws [N]
+   Tensor wcscales,    // packed ws [N]
+   Tensor out_q,       // packed attention [B, H, M, D]
+   Tensor out_k,       // packed attention [B, H, M, D]
+   Tensor out_v,       // packed attention [B, H, M, D]
+   int attn_tokens
)
{
+#ifdef __INTELLISENSE__
+   static constexpr bool USE_FP4 = false;
+#endif
+   assert(fp4 == USE_FP4);
+
    int M = act.numel() / act.shape[-1];
    int N = wgt.shape[0];
    int K = act.shape[-1] * 2;
@@ -71,90 +80,88 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>::gemm_w4a4(
            std::swap(grid.x, grid.y);
        }

-       dispatchBool(fp4, [&]<bool USE_FP4>() {
        // test_sizeof<typename Epilogue::Arguments>();
        // std::apply([](auto ...args) {
        //     (test_sizeof<decltype(args)>(), ...);
        // }, args);

        // constexpr bool FP4_AVAILABLE = __CUDA_ARCH__ >= 1200;

        if constexpr (!USE_FP4) {
            dispatchBool(act_unsigned, [&]<bool ACT_UNSIGNED>() {
                auto func = invoke_kernel<typename GEMM::gemm_w4a4_kernel<Epilogue, ACT_UNSIGNED>,
                                          const packed_act_t *,
                                          const packed_wgt_t *,
                                          const packed_ascale_t *,
                                          const packed_wscale_t *,
                                          int, int, int,
                                          typename Epilogue::Arguments,
                                          bool, bool>;
                if (shmem >= 24 * 1024) {
                    checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
                }
                assert(alpha == 1.0f);
                func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, shmem, getCurrentCUDAStream()>>>(
                    act.data_ptr<packed_act_t>(),
                    wgt.data_ptr<packed_wgt_t>(),
                    ascales.data_ptr<packed_ascale_t>(),
                    wscales.data_ptr<packed_wscale_t>(),
                    M, N, K, args, swapBlockMN, false);
                checkCUDA(cudaGetLastError());
            });
            return;
        }

        if constexpr (USE_FP4) {
            dispatchBool(alpha != 1.0f, [&]<bool USE_ALPHA>() {
                assert(!act_unsigned);
                auto func = invoke_kernel<typename GEMM::gemm_w4a4_fp4_kernel<Epilogue, USE_ALPHA>,
                                          const packed_act_t *,
                                          const packed_wgt_t *,
                                          const packed_amscale_t *,
                                          const packed_wmscale_t *,
                                          float, int, int, int,
                                          typename Epilogue::Arguments,
                                          bool, bool>;
                if (shmem >= 24 * 1024) {
                    checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
                }
                assert(ascales.dtype() == Tensor::FP8_E4M3);
                assert(wscales.dtype() == Tensor::FP8_E4M3);
                func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, shmem, getCurrentCUDAStream()>>>(
                    act.data_ptr<packed_act_t>(),
                    wgt.data_ptr<packed_wgt_t>(),
                    ascales.data_ptr<packed_amscale_t>(),
                    wscales.data_ptr<packed_wmscale_t>(),
                    alpha, M, N, K, args, swapBlockMN, false);
                checkCUDA(cudaGetLastError());
            });
            return;
        }

        // if constexpr (USE_FP4 && !FP4_AVAILABLE) {
        //     throw std::runtime_error("FP4 kernel is not available");
        // }
-       });
    };

    auto launch_bias = [&]<typename NextEpilogue>(NextEpilogue::Arguments nextArgs) {
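For context on the hunk above: dispatchBool maps a runtime flag onto a compile-time bool template argument of a generic lambda, which is what lets the body branch with `if constexpr` on USE_FP4; once USE_FP4 became a class-level template parameter, the outer dispatchBool(fp4, ...) wrapper could be dropped. A minimal sketch of such a helper, written as an assumption about its shape rather than the project's exact definition:

    #include <cstdio>
    #include <utility>

    // Sketch: turn a runtime bool into a compile-time template argument by
    // instantiating the callable for both values and selecting one at runtime.
    template<typename F>
    void dispatchBool(bool value, F &&f) {
        if (value) {
            std::forward<F>(f).template operator()<true>();
        } else {
            std::forward<F>(f).template operator()<false>();
        }
    }

    int main() {
        bool act_unsigned = true;   // runtime flag, as in the launcher above
        dispatchBool(act_unsigned, [&]<bool ACT_UNSIGNED>() {
            if constexpr (ACT_UNSIGNED) {
                std::printf("instantiate unsigned-activation kernel\n");
            } else {
                std::printf("instantiate signed-activation kernel\n");
            }
        });
    }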
@@ -262,30 +269,28 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>::gemm_w4a4(
        static constexpr float SHIFT_GELU = 0.171875f;

-       dispatchBool(fp4, [&]<bool USE_FP4>() {
        constexpr bool USE_UNSIGNED = !USE_FP4;

        using EpilogueQuantize = typename GEMM::EpilogueQuantize<false, USE_UNSIGNED, USE_FP4>;
        auto argsQuantize = typename EpilogueQuantize::Arguments{
            .qout          = qout.data_ptr<packed_act_t>(),
            .oscales       = oscales.data_ptr<typename EpilogueQuantize::oscales_t>(),
            .shift_value   = USE_FP4 ? 0.0f : SHIFT_GELU,
            .smooth_factor = smooth_factor.data_ptr<packed_wscale_t>()
        };

        // TODO: check if gelu is needed
        if (out.valid()) {
            launch_lora.template operator()<typename GEMM::EpilogueCombination<typename GEMM::EpilogueDefault, EpilogueQuantize>, typename GEMM::EpilogueGelu>({
                typename GEMM::EpilogueDefault::Arguments{
                    .out     = out.data_ptr<half_t>(),
                    .actualM = actualM,
                    .actualN = actualN,
                },
                argsQuantize
            }, {});
        } else {
            launch_lora.template operator()<EpilogueQuantize, typename GEMM::EpilogueGelu>(argsQuantize, {});
        }
-       });
    } else if (out_linearattn.valid()) {
@@ -327,17 +332,54 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>::gemm_w4a4(
        assert(norm_k.valid());

        // assert(isTypeMatch<half_t>(rotary_emb.scalar_type()));
        assert(rotary_emb.scalar_type() == Tensor::FP32);
-       assert(rotary_emb.numel() == M * GEMM::EpilogueQKVProj::HEAD_DIM / 2 * GEMM::EpilogueQKVProj::ROTARY_EMB_NUM_ELEMENTS);
-       launch_lora.template operator()<typename GEMM::EpilogueQKVProj, typename GEMM::EpilogueNop>(typename GEMM::EpilogueQKVProj::Arguments{
-           .out = out.data_ptr<half_t>(),
-           .actualM = actualM,
-           .actualN = actualN,
-           .pool_out = poolout.valid() ? poolout.data_ptr<half_t>() : nullptr,
-           .rotary_emb = rotary_emb.data_ptr<float>(),
-           .rmsnorm_weight_q = norm_q.data_ptr<half_t>(),
-           .rmsnorm_weight_k = norm_k.data_ptr<half_t>(),
-           .epsilon = 1e-6,
-       }, {});
+       assert(rotary_emb.ndims() == 3);
+       assert(rotary_emb.shape[0] * rotary_emb.shape[1] == M);
+       assert(rotary_emb.shape[2] == GEMM::EpilogueRMSNormRope::HEAD_DIM);
+       // assert(rotary_emb.numel() == M * GEMM::EpilogueQKVProj::HEAD_DIM / 2 * GEMM::EpilogueQKVProj::ROTARY_EMB_NUM_ELEMENTS);
+
+       // launch_lora.template operator()<typename GEMM::EpilogueQKVProj, typename GEMM::EpilogueNop>(typename GEMM::EpilogueQKVProj::Arguments{
+       //     .out = out.data_ptr<half_t>(),
+       //     .actualM = actualM,
+       //     .actualN = actualN,
+       //     .pool_out = poolout.valid() ? poolout.data_ptr<half_t>() : nullptr,
+       //     .rotary_emb = rotary_emb.data_ptr<float>(),
+       //     .rmsnorm_weight_q = norm_q.data_ptr<half_t>(),
+       //     .rmsnorm_weight_k = norm_k.data_ptr<half_t>(),
+       //     .epsilon = 1e-6,
+       // }, {});
+
+       using EpilogueRope = typename GEMM::EpilogueRMSNormRope;
+       auto argsRope = typename GEMM::EpilogueRMSNormRope::Arguments{
+           .rotary_emb       = rotary_emb.data_ptr<typename EpilogueRope::packed_rotemb_t>(),
+           .rmsnorm_weight_q = norm_q.data_ptr<half_t>(),
+           .rmsnorm_weight_k = norm_k.data_ptr<half_t>(),
+           .epsilon          = 1e-6,
+       };
+
+       if (out_q.valid()) {
+           launch_lora.template operator()<typename GEMM::EpilogueCombination<EpilogueRope, typename GEMM::EpiloguePackQKV>, typename GEMM::EpilogueNop>({
+               argsRope,
+               typename GEMM::EpiloguePackQKV::Arguments{
+                   .out_q        = out_q.data_ptr<typename GEMM::EpiloguePackQKV::packed_qkv_t>(),
+                   .out_k        = out_k.data_ptr<typename GEMM::EpiloguePackQKV::packed_qkv_t>(),
+                   .out_v        = out_v.data_ptr<typename GEMM::EpiloguePackQKV::packed_qkv_t>(),
+                   .actualM      = attn_tokens,
+                   .strideHead_q = int(out_q.stride(1) * out_q.scalar_size() / sizeof(typename GEMM::EpiloguePackQKV::packed_qkv_t)),
+                   .strideHead_k = int(out_k.stride(1) * out_k.scalar_size() / sizeof(typename GEMM::EpiloguePackQKV::packed_qkv_t)),
+                   .strideHead_v = int(out_v.stride(1) * out_v.scalar_size() / sizeof(typename GEMM::EpiloguePackQKV::packed_qkv_t)),
+               }
+           }, {});
+       } else {
+           launch_lora.template operator()<typename GEMM::EpilogueCombination<EpilogueRope, typename GEMM::EpilogueDefault>, typename GEMM::EpilogueNop>({
+               argsRope,
+               typename GEMM::EpilogueDefault::Arguments{
+                   .out     = out.data_ptr<half_t>(),
+                   .actualM = actualM,
+                   .actualN = actualN,
+               }
+           }, {});
+       }
    } else if (out.valid()) {
        using Epilogue = typename GEMM::EpilogueDefault;
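The strideHead_* arguments above convert the tensor's head-to-head stride, reported in scalar elements, into units of the epilogue's packed_qkv_t. A rough worked example under stated assumptions (FP16 outputs laid out [B, H, M, D] and a hypothetical 8-byte packed element; the real packed_qkv_t is defined in the kernel headers and may differ):

    #include <cstdint>
    #include <cstdio>

    struct packed_qkv_t { uint32_t u[2]; };   // hypothetical 8-byte packed element

    int main() {
        // out_q is [B, H, M, D] in FP16, so stride(1) (elements between heads) = M * D.
        const int  M = 4096, D = 128;                      // illustrative sizes
        const long strideElems = long(M) * D;              // out_q.stride(1)
        const int  scalarSize  = 2;                        // FP16, out_q.scalar_size()
        const int  strideHead  = int(strideElems * scalarSize / sizeof(packed_qkv_t));
        std::printf("strideHead = %d packed elements per head\n", strideHead);
    }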
@@ -357,8 +399,8 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>::gemm_w4a4(
    }
}

-template<typename Config>
-void GEMM_W4A4_Launch<Config>::linearattn_vk_mul_q(Tensor q, Tensor vk) {
+template<typename Config, bool USE_FP4>
+void GEMM_W4A4_Launch<Config, USE_FP4>::linearattn_vk_mul_q(Tensor q, Tensor vk) {
    using Epilogue = typename GEMM::EpilogueLiteLA;

    int batch_size = vk.shape[0];
@@ -384,8 +426,8 @@ void GEMM_W4A4_Launch<Config>::linearattn_vk_mul_q(Tensor q, Tensor vk) {
    checkCUDA(cudaGetLastError());
}

-template<typename Config>
-void GEMM_W4A4_Launch<Config>::quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu, bool fp4) {
+template<typename Config, bool USE_FP4>
+void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu, bool fp4) {
    const int actualM = input.numel() / input.shape[-1];
    const int actualN = input.shape[-1];
@@ -418,38 +460,41 @@ void GEMM_W4A4_Launch<Config>::quantize_w4a4_act_fuse_lora(Tensor input, Tensor
    dispatchVal(rank, LoraRanks(), [&]<int RANK>() {
        dispatchBool(fuse_glu, [&]<bool FUSE_GLU>() {
-           dispatchBool(fp4, [&]<bool USE_FP4>() {
            using Lora   = typename GEMM::Lora<RANK>;
            using kernel = typename Lora::quantize_w4a4_fuse_lora_kernel<FUSE_GLU, USE_FP4>;

            auto func = invoke_kernel<kernel, typename kernel::Arguments>;
            checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));

            // log(std::format("quantize_w4a4_act_fuse_lora M={} N={} input={} output={} (size={} numel={})", M, N, input.data_ptr(), output.data_ptr(), output.buffer->getSize(), output.numel()));

            func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE, getCurrentCUDAStream()>>>(typename kernel::Arguments{
                .input         = input.data_ptr<half_t>(),
                .smooth_factor = smooth.valid() ? smooth.data_ptr<packed_wscale_t>() : nullptr,
                .output        = output.data_ptr<packed_act_t>(),
                .oscales       = oscales.data_ptr<typename kernel::oscales_t>(),
                .lora_wgt_down = lora_down.data_ptr<packed_fpsum_t>(),
                .lora_act      = lora_act_out.data_ptr<float>(),
                .M             = M,
                .N             = N,
                .actualM       = actualM,
                .actualN       = actualN,
            });
            checkCUDA(cudaGetLastError());
-           });
        });
    });
}

-template<typename Config>
-void GEMM_W4A4_Launch<Config>::quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales) {
+template<typename Config, bool USE_FP4>
+void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales) {
+   if constexpr (USE_FP4) {
+       assert(false);  // not implemented
+       return;
+   }
+
    int M = input.numel() / input.shape[-1];
    int K = input.shape[-1];
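The guard added to quantize_w4a4_act (and, in the next hunk, to quantize_w4a4_wgt) relies on `if constexpr`, so the USE_FP4 = true instantiation still compiles but fails fast if it is ever called. A small self-contained sketch of the pattern, with illustrative names only:

    #include <cassert>

    // Sketch: keep a template instantiable for both values of a flag while making
    // the unimplemented variant fail loudly at runtime (in debug builds).
    template<bool USE_FP4>
    void quantize_stub() {
        if constexpr (USE_FP4) {
            assert(false && "FP4 quantization not implemented");
            return;
        }
        // ... INT4 path would continue here ...
    }

    int main() {
        quantize_stub<false>();    // fine
        // quantize_stub<true>();  // compiles, but trips the assert when called
    }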
@@ -471,8 +516,13 @@ void GEMM_W4A4_Launch<Config>::quantize_w4a4_act(Tensor input, Tensor output, Te
    checkCUDA(cudaGetLastError());
}

-template<typename Config>
-void GEMM_W4A4_Launch<Config>::quantize_w4a4_wgt(Tensor input, Tensor output, Tensor oscales) {
+template<typename Config, bool USE_FP4>
+void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_wgt(Tensor input, Tensor output, Tensor oscales) {
+   if constexpr (USE_FP4) {
+       assert(false);
+       return;
+   }
+
    int N = input.numel() / input.shape[-1];
    int K = input.shape[-1];
src/kernels/zgemm/gemm_w4a4_test.cu (new file, mode 100644)

#include "zgemm.h"
#include "gemm_w4a4.cuh"

namespace nunchaku::kernels {

void test_rmsnorm_rope(Tensor input, Tensor output, Tensor norm_q, Tensor norm_k, Tensor rotary_emb) {
    assert(input.ndims() == 2);

    const int M = input.shape[0];
    const int N = input.shape[1];

    assert(input.shape.dataExtent == output.shape.dataExtent);
    assert(input.scalar_type() == Tensor::FP16);

    using GEMM     = GEMM_W4A4<GEMMConfig_W4A4_FP16>;
    using Epilogue = GEMM::EpilogueRMSNormRope;

    assert(M % GEMM::BLOCK_M == 0);
    assert(N % GEMM::BLOCK_N == 0);

    using kernel = typename GEMM::test_epilogue_kernel<Epilogue>;
    auto func    = invoke_kernel<kernel, typename kernel::Arguments>;
    checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));

    dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);
    func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE, getCurrentCUDAStream()>>>(typename kernel::Arguments{
        .input   = input.data_ptr<GEMM::half_t>(),
        .output  = output.data_ptr<GEMM::half_t>(),
        .M       = M,
        .N       = N,
        .actualM = M,
        .actualN = N,
        .argsEpilogue = typename Epilogue::Arguments{
            .rotary_emb       = rotary_emb.data_ptr<typename Epilogue::packed_rotemb_t>(),
            .rmsnorm_weight_q = norm_q.data_ptr<GEMM::half_t>(),
            .rmsnorm_weight_k = norm_k.data_ptr<GEMM::half_t>(),
            .epsilon          = 1e-6,
        }
    });
    checkCUDA(cudaGetLastError());
}

void test_pack_qkv(Tensor input, Tensor out_q, Tensor out_k, Tensor out_v, int numTokens) {
    assert(input.ndims() == 2);

    const int M = input.shape[0];
    const int N = input.shape[1];

    assert(input.scalar_type() == Tensor::FP16);

    Tensor output = Tensor::empty_like(input);

    using GEMM     = GEMM_W4A4<GEMMConfig_W4A4_FP16>;
    using Epilogue = GEMM::EpiloguePackQKV;

    assert(M % GEMM::BLOCK_M == 0);
    assert(N % GEMM::BLOCK_N == 0);

    using kernel = typename GEMM::test_epilogue_kernel<Epilogue>;
    auto func    = invoke_kernel<kernel, typename kernel::Arguments>;
    checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));

    dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);
    func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE, getCurrentCUDAStream()>>>(typename kernel::Arguments{
        .input   = input.data_ptr<GEMM::half_t>(),
        .output  = output.data_ptr<GEMM::half_t>(),
        .M       = M,
        .N       = N,
        .actualM = M,
        .actualN = N,
        .argsEpilogue = typename Epilogue::Arguments{
            .out_q        = out_q.data_ptr<typename Epilogue::packed_qkv_t>(),
            .out_k        = out_k.data_ptr<typename Epilogue::packed_qkv_t>(),
            .out_v        = out_v.data_ptr<typename Epilogue::packed_qkv_t>(),
            .actualM      = numTokens,
            .strideHead_q = int(out_q.stride(1) * out_q.scalar_size() / sizeof(GEMM::EpiloguePackQKV::packed_qkv_t)),
            .strideHead_k = int(out_k.stride(1) * out_k.scalar_size() / sizeof(GEMM::EpiloguePackQKV::packed_qkv_t)),
            .strideHead_v = int(out_v.stride(1) * out_v.scalar_size() / sizeof(GEMM::EpiloguePackQKV::packed_qkv_t)),
        }
    });
    checkCUDA(cudaGetLastError());
}

};  // namespace nunchaku::kernels
\ No newline at end of file
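For reference, a plain host-side sketch of the per-head math that the EpilogueRMSNormRope path (exercised by test_rmsnorm_rope above) fuses: RMS-normalize each head, scale by the learned weight, then apply the rotary embedding. The adjacent-pair rotation convention and the absence of any packing here are simplifying assumptions; the kernel's exact layout and pairing may differ.

    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Reference only: x <- (x / rms(x)) * w, then rotate pairs (x[2i], x[2i+1])
    // by per-position angles given as cos/sin tables.
    void rmsnorm_rope_reference(std::vector<float> &head,          // [HEAD_DIM]
                                const std::vector<float> &weight,  // [HEAD_DIM]
                                const std::vector<float> &cos_v,   // [HEAD_DIM / 2]
                                const std::vector<float> &sin_v,   // [HEAD_DIM / 2]
                                float epsilon = 1e-6f) {
        const size_t d = head.size();
        float sumsq = 0.f;
        for (float x : head) sumsq += x * x;
        const float inv_rms = 1.f / std::sqrt(sumsq / d + epsilon);
        for (size_t i = 0; i < d; ++i) head[i] = head[i] * inv_rms * weight[i];
        for (size_t i = 0; i < d / 2; ++i) {
            const float x0 = head[2 * i], x1 = head[2 * i + 1];
            head[2 * i]     = x0 * cos_v[i] - x1 * sin_v[i];
            head[2 * i + 1] = x0 * sin_v[i] + x1 * cos_v[i];
        }
    }

    int main() {
        std::vector<float> head = {1.f, 2.f, 3.f, 4.f};
        std::vector<float> w(4, 1.f), c(2, 1.f), s(2, 0.f);   // identity rotation
        rmsnorm_rope_reference(head, w, c, s);
        for (float x : head) std::printf("%f ", x);
        std::printf("\n");
    }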
src/kernels/zgemm/zgemm.h

@@ -30,7 +30,11 @@ void gemm_w4a4(
    bool fuse_silu, bool fp4,
    float alpha,
-   Tensor wcscales
+   Tensor wcscales,
+   Tensor out_q,   // packed attention [B, H, M, D]
+   Tensor out_k,   // packed attention [B, H, M, D]
+   Tensor out_v,   // packed attention [B, H, M, D]
+   int attn_tokens
);

void linearattn_vk_mul_q(Tensor q, Tensor vk);

@@ -57,4 +61,17 @@ void quantize_w8a8_act(Tensor input, Tensor output, Tensor oscales, bool fuse_gl
// Tensor wscales // [1, N]
// );

+void attention_fp16(
+   Tensor q,   // packed [Batch, Head, TokensQ, HEAD_DIM]
+   Tensor k,   // packed [Batch, Head, TokensKV, HEAD_DIM]
+   Tensor v,   // packed [Batch, Head, TokensKV, HEAD_DIM]
+   Tensor o,   // linear [Batch, TokensQ, Head * HEAD_DIM]
+   float scale
+);
+
+// FOR TEST ONLY
+void test_rmsnorm_rope(Tensor input, Tensor output, Tensor norm_q, Tensor norm_k, Tensor rotary_emb);
+void test_pack_qkv(Tensor input, Tensor out_q, Tensor out_k, Tensor out_v, int numTokens);
+
};  // namespace nunchaku::kernels
\ No newline at end of file
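The new attention_fp16 entry point declared above takes packed FP16 Q/K/V tensors, a linear output tensor, and a softmax scale. As orientation, here is a plain single-head FP32 reference sketch of the computation it performs, O = softmax(scale · Q·Kᵀ)·V; the usual choice scale = 1/sqrt(HEAD_DIM) is an assumption, and the packed [B, H, M, D] layout and the FP16 kernel itself are not reproduced here.

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Single-head reference: O = softmax(scale * Q K^T) V, row-major [tokens, D].
    void attention_reference(const std::vector<float> &Q, const std::vector<float> &K,
                             const std::vector<float> &V, std::vector<float> &O,
                             int Mq, int Mk, int D, float scale) {
        std::vector<float> row(Mk);
        for (int i = 0; i < Mq; ++i) {
            float maxv = -1e30f;
            for (int j = 0; j < Mk; ++j) {
                float dot = 0.f;
                for (int d = 0; d < D; ++d) dot += Q[i * D + d] * K[j * D + d];
                row[j] = dot * scale;
                maxv   = std::max(maxv, row[j]);
            }
            float denom = 0.f;
            for (int j = 0; j < Mk; ++j) { row[j] = std::exp(row[j] - maxv); denom += row[j]; }
            for (int d = 0; d < D; ++d) {
                float acc = 0.f;
                for (int j = 0; j < Mk; ++j) acc += (row[j] / denom) * V[j * D + d];
                O[i * D + d] = acc;
            }
        }
    }

    int main() {
        const int Mq = 2, Mk = 3, D = 4;
        std::vector<float> Q(Mq * D, 0.5f), K(Mk * D, 0.25f), V(Mk * D, 1.f), O(Mq * D);
        attention_reference(Q, K, V, O, Mq, Mk, D, 1.f / std::sqrt(float(D)));
        std::printf("O[0] = %f\n", O[0]);  // expect 1.0 since all V rows are identical
    }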