fengzch-das / nunchaku / Commits

Commit e9ad0535
Authored Jan 23, 2025 by muyangli

[major] support SANA

Parent: 9eb2cee0
Changes: 86 files in this commit; this page shows 6 changed files, with 1175 additions and 11 deletions (+1175, -11).
src/kernels/zgemm/gemm_w4a4_launch_bf16.cu    +5    -0
src/kernels/zgemm/gemm_w4a4_launch_fp16.cu    +5    -0
src/kernels/zgemm/gemm_w4a4_launch_impl.cuh   +432  -0
src/kernels/zgemm/gemm_w8a8.cu                +196  -0
src/kernels/zgemm/gemm_w8a8.cuh               +517  -0
src/kernels/zgemm/zgemm.h                     +20   -11
src/kernels/zgemm/gemm_w4a4_launch_bf16.cu (new file, mode 100644)
#include "gemm_w4a4_launch_impl.cuh"

namespace nunchaku::kernels {

template class GEMM_W4A4_Launch<GEMMConfig_W4A4_BF16>;

};
\ No newline at end of file
src/kernels/zgemm/gemm_w4a4_launch_fp16.cu (new file, mode 100644)
#include "gemm_w4a4_launch_impl.cuh"

namespace nunchaku::kernels {

template class GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>;

};
\ No newline at end of file
src/kernels/zgemm/gemm_w4a4_launch_impl.cuh (new file, mode 100644; diff collapsed, not shown)
src/kernels/zgemm/gemm_w8a8.cu (new file, mode 100644)
#include "zgemm.h"
#include "gemm_w8a8.cuh"
namespace
nunchaku
::
kernels
{
void
quantize_w8a8_act
(
Tensor
input
,
Tensor
output
,
Tensor
oscales
,
bool
fuse_glu
)
{
using
GEMM
=
GEMM_W8A8
;
int
M
=
input
.
numel
()
/
input
.
shape
[
-
1
];
int
K
=
input
.
shape
[
-
1
];
assert
(
output
.
dtype
()
==
Tensor
::
INT8
);
assert
(
output
.
numel
()
/
output
.
shape
[
-
1
]
==
M
);
    assert(output.shape[-1] == (fuse_glu ? K / 2 : K));
    assert(isTypeMatch<GEMM::half_t>(oscales.dtype()));
    assert(oscales.numel() == M * 1);

    auto launch = [&]<bool FUSE_GLU>() {
        using kernel = GEMM::quantize_w8a8_act_kernel<FUSE_GLU>;
        assert(kernel::check(M, K));

        dim3 grid  = kernel::gridSize(M, K);
        dim3 block = kernel::blockSize(M, K);

        auto func = invoke_kernel<kernel,
            const GEMM::half_t *,
            GEMM::packed_act_t *,
            GEMM::packed_ascale_t *,
            int,
            bool>;
        checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, 92160));

        func<<<grid, block, kernel::smemSize(M, K)>>>(
            input.data_ptr<GEMM::half_t>(),
            output.data_ptr<GEMM::packed_act_t>(),
            oscales.data_ptr<GEMM::packed_ascale_t>(),
            K,
            false
        );
        checkCUDA(cudaGetLastError());
    };

    if (fuse_glu) {
        launch.template operator()<true>();
    } else {
        launch.template operator()<false>();
    }
}
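Note: the FUSE_GLU dispatch above turns a runtime flag into a compile-time template argument through a C++20 templated lambda. A minimal standalone sketch of that pattern, with illustrative names that are not part of this commit:

// dispatch_sketch.cpp -- compile with -std=c++20 (illustrative only)
#include <cstdio>

void quantize_dispatch(bool fuse_glu) {
    // The lambda carries the flag as a non-type template parameter,
    // so each instantiation is branch-free at the point of use.
    auto launch = [&]<bool FUSE_GLU>() {
        if constexpr (FUSE_GLU) {
            std::printf("instantiated with FUSE_GLU = true\n");
        } else {
            std::printf("instantiated with FUSE_GLU = false\n");
        }
    };

    // Same call syntax as in quantize_w8a8_act above.
    if (fuse_glu) {
        launch.template operator()<true>();
    } else {
        launch.template operator()<false>();
    }
}

int main() {
    quantize_dispatch(true);
    quantize_dispatch(false);
    return 0;
}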
void gemm_w8a8(Tensor act,      // [M, K]
               Tensor wgt,      // [N, K]
               Tensor out,      // [M, N]
               Tensor ascales,  // [1, M]
               Tensor wscales,  // [1, N]
               Tensor bias)
{
    using GEMM = GEMM_W8A8;

    int M = act.numel() / act.shape[-1];
    int N = wgt.shape[0];
    int K = act.shape[-1];
    assert(K == wgt.shape[1]);

    int actualM = 0;
    int actualN = 0;
    if (out.valid()) {
        actualM = out.numel() / out.shape[-1];
        actualN = out.shape[-1];
        assert(actualM <= M && M - actualM < GEMM::BLOCK_M);
        assert(actualN <= N && N - actualN < GEMM::BLOCK_N);
    }

    auto launch = [&]<typename Epilogue>(Epilogue::Arguments args) {
        dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);

        bool swapBlockMN = M > N * 2;
        if (swapBlockMN) {
            std::swap(grid.x, grid.y);
        }

        invoke_kernel<GEMM::gemm_w8a8_kernel<Epilogue>><<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS>>>(
            act.data_ptr<GEMM::packed_act_t>(),
            wgt.data_ptr<GEMM::packed_wgt_t>(),
            ascales.data_ptr<GEMM::packed_ascale_t>(),
            wscales.data_ptr<GEMM::packed_wscale_t>(),
            // out.valid() ? out.data_ptr<GEMM::half_t>() : nullptr,
            M, N, K,
            args,
            swapBlockMN,
            false
        );
        checkCUDA(cudaGetLastError());
    };

    auto launch_bias = [&]<typename NextEpilogue>(NextEpilogue::Arguments nextArgs) {
        if (!bias.valid()) {
            return launch.template operator()<NextEpilogue>(nextArgs);
        }
        assert(bias.numel() == N);
        // append EpilogueNop to work around the mismatched memory layout of std::tuple
        // between device and host code on Windows
        // ** sizeof(std::tuple<std::tuple<int>>) == 8 on device **
        using Epilogue = GEMM::EpilogueCombination<GEMM::EpilogueBias, NextEpilogue, GEMM::EpilogueNop>;
        return launch.template operator()<Epilogue>({
            GEMM::EpilogueBias::Arguments{
                .bias = bias.data_ptr<GEMM::packed_wscale_t>(),
            },
            nextArgs,
            {}
        });
    };

    launch_bias.template operator()<GEMM::EpilogueDefault>(GEMM::EpilogueDefault::Arguments{
        .out     = out.data_ptr<GEMM::half_t>(),
        .actualM = actualM,
        .actualN = actualN,
    });
}
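Note: the launch lambda above transposes the block grid when M is more than twice N and passes swapBlockMN so the kernel can undo the swap when computing its tile indices. A minimal standalone CUDA sketch of that heuristic, with hypothetical names not taken from this commit:

// grid_swap_sketch.cu -- illustrative only, not part of the commit
#include <cstdio>
#include <utility>
#include <cuda_runtime.h>

__global__ void tile_kernel(int tilesM, int tilesN, bool swapped) {
    // Recover the logical (M, N) tile indices regardless of launch order.
    int bm = swapped ? blockIdx.y : blockIdx.x;
    int bn = swapped ? blockIdx.x : blockIdx.y;
    if (bm >= tilesM || bn >= tilesN) return;
    if (threadIdx.x == 0 && bm == 0 && bn == 0) {
        printf("grid %u x %u, swapped = %d\n", gridDim.x, gridDim.y, (int)swapped);
    }
    // ... per-tile GEMM work for tile (bm, bn) would go here ...
}

int main() {
    const int M = 4096, N = 256, BLOCK_M = 128, BLOCK_N = 128;
    dim3 grid(M / BLOCK_M, N / BLOCK_N);

    // Same heuristic as in gemm_w8a8: launch the grid transposed when M >> N.
    bool swapBlockMN = M > N * 2;
    if (swapBlockMN) {
        std::swap(grid.x, grid.y);
    }

    tile_kernel<<<grid, 128>>>(M / BLOCK_M, N / BLOCK_N, swapBlockMN);
    cudaDeviceSynchronize();
    return 0;
}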
#if 0
void gemm_w8a8_fuse_litela(
    Tensor act,     // [B, (M), K]
    Tensor wgt,     // [N, K]
    Tensor out_q,   // [B, (M), N / 3]
    Tensor out_vk,  // [B, num_heads, head_dim + 1, head_dim]
    Tensor ascales, // [1, M]
    Tensor wscales  // [1, N]
) {
    using GEMM = GEMM_W8A8;
    using Epilogue = GEMM::EpilogueLiteLA;

    int M = act.numel() / act.shape[-1];
    int N = wgt.shape[0];
    int K = act.shape[-1];
    assert(K == wgt.shape[1]);

    assert(out_vk.ndims() == 4);
    assert(out_vk.shape[2] == Epilogue::LITELA_HEAD_DIM + 1);
    assert(out_vk.shape[3] == Epilogue::LITELA_HEAD_DIM);
    assert(out_vk.shape[1] * Epilogue::LITELA_HEAD_DIM * 3 == N);

    int batch_size = out_vk.shape[0];
    int num_heads = out_vk.shape[1];
    assert(M % batch_size == 0);
    int batch_m = M / batch_size;

    Epilogue::Arguments epilogueArgs;
    epilogueArgs.batch_m = act.shape[1];
    epilogueArgs.out_q = out_q.data_ptr<GEMM::half_t>();
    epilogueArgs.out_vk = out_vk.data_ptr<float>();

    checkCUDA(cudaMemsetAsync(out_vk.data_ptr(), 0, out_vk.buffer->getSize()));

    auto func = invoke_kernel<GEMM::gemm_w8a8_kernel<Epilogue>,
        const GEMM::packed_act_t *,
        const GEMM::packed_wgt_t *,
        const GEMM::packed_ascale_t *,
        const GEMM::packed_wscale_t *,
        // GEMM::half_t *,
        int, int, int,
        Epilogue::Arguments,
        bool,
        bool>;

    checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, Epilogue::SHMEM_SIZE));

    dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);
    bool swapBlockMN = M > N * 2;
    if (swapBlockMN) {
        std::swap(grid.x, grid.y);
    }

    func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, Epilogue::SHMEM_SIZE>>>(
        act.data_ptr<GEMM::packed_act_t>(),
        wgt.data_ptr<GEMM::packed_wgt_t>(),
        ascales.data_ptr<GEMM::packed_ascale_t>(),
        wscales.data_ptr<GEMM::packed_wscale_t>(),
        // nullptr,
        M, N, K, epilogueArgs,
        swapBlockMN,
        false
    );
    checkCUDA(cudaGetLastError());

    invoke_kernel<Epilogue::vk_mul_q_kernel><<<dim3(batch_m / 128, num_heads, batch_size), 128>>>(
        out_q.data_ptr<GEMM::half_t>(),
        out_vk.data_ptr<float>(),
        1e-6f
    );
    checkCUDA(cudaGetLastError());
}
#endif
}; // namespace nunchaku::kernels
\ No newline at end of file
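Note: quantize_w8a8_act above requests 92160 bytes of dynamic shared memory, which is above the default 48 KB per-block limit and therefore requires the cudaFuncSetAttribute opt-in. A minimal standalone CUDA sketch of that opt-in, with illustrative names not taken from this commit:

// smem_optin_sketch.cu -- illustrative only, not part of the commit
#include <cstdio>
#include <cuda_runtime.h>

__global__ void big_smem_kernel(float *out) {
    extern __shared__ float smem[];   // sized at launch time
    smem[threadIdx.x] = (float)threadIdx.x;
    __syncthreads();
    if (threadIdx.x == 0) out[blockIdx.x] = smem[0];
}

int main() {
    const int smemSize = 92160;  // 90 KB, the same figure used by quantize_w8a8_act

    // Kernels asking for more than 48 KB of dynamic shared memory must
    // raise this attribute first (supported on sm_70 and newer, up to the
    // device-specific maximum).
    cudaFuncSetAttribute(big_smem_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smemSize);

    float *out = nullptr;
    cudaMalloc(&out, sizeof(float));
    big_smem_kernel<<<1, 128, smemSize>>>(out);
    printf("launch status: %s\n", cudaGetErrorString(cudaGetLastError()));
    cudaDeviceSynchronize();
    cudaFree(out);
    return 0;
}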
src/kernels/zgemm/gemm_w8a8.cuh (new file, mode 100644; diff collapsed, not shown)
src/kernels/gemm_w4a4.h → src/kernels/zgemm/zgemm.h (renamed; diff collapsed, not shown)