fengzch-das / nunchaku · Commits

Commit e9ad0535, authored Jan 23, 2025 by muyangli
    [major] support SANA
parent 9eb2cee0

Changes: 86. Showing 6 changed files with 1175 additions and 11 deletions (+1175 -11):
src/kernels/zgemm/gemm_w4a4_launch_bf16.cu   +5    -0
src/kernels/zgemm/gemm_w4a4_launch_fp16.cu   +5    -0
src/kernels/zgemm/gemm_w4a4_launch_impl.cuh  +432  -0
src/kernels/zgemm/gemm_w8a8.cu               +196  -0
src/kernels/zgemm/gemm_w8a8.cuh              +517  -0
src/kernels/zgemm/zgemm.h                    +20   -11
src/kernels/zgemm/gemm_w4a4_launch_bf16.cu (new file, mode 100644)
#include "gemm_w4a4_launch_impl.cuh"

namespace nunchaku::kernels {

template class GEMM_W4A4_Launch<GEMMConfig_W4A4_BF16>;

};
src/kernels/zgemm/gemm_w4a4_launch_fp16.cu (new file, mode 100644)
#include "gemm_w4a4_launch_impl.cuh"

namespace nunchaku::kernels {

template class GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>;

};
src/kernels/zgemm/gemm_w4a4_launch_impl.cuh (new file, mode 100644)
#include "gemm_w4a4_launch.cuh"
namespace
nunchaku
::
kernels
{
#ifndef __INTELLISENSE__
template
<
typename
Config
>
void
GEMM_W4A4_Launch
<
Config
>::
gemm_w4a4
(
#else
template
<
>
void
GEMM_W4A4_Launch
<
GEMMConfig_W4A4_FP16
>::
gemm_w4a4
(
#endif
Tensor
act
,
// packed act [M, K / 2]
Tensor
wgt
,
// packed act [N, K / 2]
Tensor
out
,
// linear [M, N]
Tensor
qout
,
// packed act [M, N / 2]
Tensor
ascales
,
// packed as [K / 64, M]
Tensor
wscales
,
// packed ws [K / 64, N]
Tensor
oscales
,
// packed as [N / 64, M]
Tensor
poolout
,
// linear [M / PoolSize, N]
Tensor
lora_act_in
,
// packed lora_act [M, R]
Tensor
lora_up
,
// packed lora_wgt [N, R]
Tensor
lora_down
,
// packed lora_wgt [N, R]
Tensor
lora_act_out
,
// packed lora_act [M, R]
Tensor
norm_q
,
// linear [HEAD_DIM]
Tensor
norm_k
,
// linear [HEAD_DIM]
Tensor
rotary_emb
,
// linear [M, HEAD_DIM / 2, 2, 2]
Tensor
bias
,
// packed ws [N]
Tensor
smooth_factor
,
// packed ws [N], for quantization of the next layer
Tensor
out_vk
,
// linear [B, num_heads, head_dim + 1, head_dim]
Tensor
out_linearattn
,
// linear [B, (M), N / 3]
bool
act_unsigned
,
std
::
vector
<
float
>
lora_scales
,
// [R / 16]
bool
fuse_silu
)
{
int
M
=
act
.
numel
()
/
act
.
shape
[
-
1
];
int
N
=
wgt
.
shape
[
0
];
int
K
=
act
.
shape
[
-
1
]
*
2
;
assert
(
K
==
wgt
.
shape
[
1
]
*
2
);
int
actualM
=
0
;
int
actualN
=
0
;
if
(
out
.
valid
())
{
actualM
=
out
.
numel
()
/
out
.
shape
[
-
1
];
actualN
=
out
.
shape
[
-
1
];
assert
(
actualM
<=
M
&&
M
-
actualM
<
GEMM
::
BLOCK_M
);
assert
(
actualN
<=
N
&&
N
-
actualN
<
GEMM
::
BLOCK_N
);
}
spdlog
::
trace
(
"gemm_w4a4: M={} N={} K={}"
,
M
,
N
,
K
);
spdlog
::
trace
(
"act at {}"
,
act
.
data_ptr
());
spdlog
::
trace
(
"wgt at {}"
,
wgt
.
data_ptr
());
spdlog
::
trace
(
"ascales at {}"
,
ascales
.
data_ptr
());
spdlog
::
trace
(
"wscales at {}"
,
wscales
.
data_ptr
());
if
(
bias
.
valid
())
{
spdlog
::
trace
(
"bias at {}"
,
bias
.
data_ptr
());
}
int
shmem
=
0
;
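
    // Note: act and wgt pack two 4-bit values per int8 element, so the logical reduction
    // dimension is K = act.shape[-1] * 2. M and N are the block-aligned (padded) problem sizes,
    // while `out` keeps the caller's unpadded sizes, which is why actualM/actualN may fall short
    // of M/N by less than one block. A shape sketch, with BLOCK_M = 256 assumed for illustration:
    //   actualM = 250, act packed as [256, K / 2]  =>  M = 256, M - actualM = 6 < BLOCK_M, asserts hold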

    auto launch = [&]<typename Epilogue>(Epilogue::Arguments args) {
        assert(M % GEMM::BLOCK_M == 0);
        assert(N % GEMM::BLOCK_N == 0);

        dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);

        bool swapBlockMN = M > N * 2;
        if (swapBlockMN) {
            std::swap(grid.x, grid.y);
        }

        dispatchBool(act_unsigned, [&]<bool ACT_UNSIGNED>() {
            // test_sizeof<typename Epilogue::Arguments>();
            // std::apply([](auto ...args) {
            //     (test_sizeof<decltype(args)>(), ...);
            // }, args);

            using kernel = typename GEMM::gemm_w4a4_kernel<Epilogue, ACT_UNSIGNED>;
            auto func = invoke_kernel<kernel,
                const packed_act_t *,
                const packed_wgt_t *,
                const packed_ascale_t *,
                const packed_wscale_t *,
                int, int, int,
                typename Epilogue::Arguments,
                bool,
                bool>;

            if (shmem >= 24 * 1024) {
                checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
            }

            func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, shmem>>>(
                act.data_ptr<packed_act_t>(),
                wgt.data_ptr<packed_wgt_t>(),
                ascales.data_ptr<packed_ascale_t>(),
                wscales.data_ptr<packed_wscale_t>(),
                M, N, K, args,
                swapBlockMN,
                false
            );
            checkCUDA(cudaGetLastError());
        });
    };
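
    // Note: `launch` is a generic lambda with an explicit template parameter list. Because
    // `Epilogue::Arguments` is a non-deduced context, the epilogue type cannot be inferred from
    // the argument and has to be named at each call site, e.g. (illustrative only):
    //
    //   launch.template operator()<typename GEMM::EpilogueDefault>(args);
    //
    // launch_bias and launch_lora below wrap this call to prepend bias / LoRA epilogues.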

    auto launch_bias = [&]<typename NextEpilogue>(NextEpilogue::Arguments nextArgs) {
        if (!bias.valid()) {
            return launch.template operator()<NextEpilogue>(nextArgs);
        }

        assert(bias.numel() == N);

        // append EpilogueNop to work around mismatched memory layout of std::tuple between device and host code on Windows
        // ** sizeof(std::tuple<std::tuple<int>>) == 8 on device **
        using Epilogue = typename GEMM::EpilogueCombination<typename GEMM::EpilogueBias, NextEpilogue, typename GEMM::EpilogueNop>;
        return launch.template operator()<Epilogue>({
            typename GEMM::EpilogueBias::Arguments{
                .bias = bias.data_ptr<packed_wscale_t>(),
            },
            nextArgs,
            {}
        });
    };
    // auto launch_bias = launch;

    auto launch_lora = [&]<typename NextEpilogue, typename MidEpilogue>(NextEpilogue::Arguments nextArgs, MidEpilogue::Arguments midArgs) {
        assert(lora_up.valid() == lora_act_in.valid());
        assert(lora_down.valid() == lora_act_out.valid());

        if (!lora_up.valid()) {
            assert(!lora_down.valid());
            return launch_bias.template operator()<typename GEMM::EpilogueCombination<MidEpilogue, NextEpilogue>>({midArgs, nextArgs});
        }

        const int rank_up = lora_up.shape[1];

        assert(lora_up.shape[0] == N);
        // assert(lora_up.shape[1] == Lora::LORA_RANK);
        assert(lora_act_in.shape[0] == M);
        assert(lora_act_in.shape[1] == rank_up);

        dispatchVal(rank_up, LoraRanks(), [&]<int RANK_UP>() {
            using LoraUp = typename GEMM::Lora<RANK_UP>;
            using scale_t = typename LoraUp::scale_t;

            scale_t scales;
            if constexpr (scales.size() > 0) {
                assert(lora_scales.size() >= scales.size());
                for (size_t i = 0; i < scales.size(); i++) {
                    scales[i] = lora_scales[i];
                }
            }

            if (!lora_down.valid()) {
                using Epilogue = typename GEMM::EpilogueCombination<typename LoraUp::EpilogueLoraUp, MidEpilogue, NextEpilogue, typename GEMM::EpilogueNop>;
                return launch_bias.template operator()<Epilogue>({
                    typename LoraUp::EpilogueLoraUp::Arguments{
                        .lora_act = lora_act_in.data_ptr<float>(),
                        .lora_wgt_up = lora_up.data_ptr<packed_fpsum_t>(),
                        .scales = scales,
                    },
                    midArgs,
                    nextArgs,
                    {}
                });
            }

            const int rank_down = lora_down.shape[1];

            assert(rank_down == rank_up);
            assert(lora_down.shape[0] == N);
            // assert(lora_down.shape[1] == Lora::LORA_RANK);
            assert(lora_act_out.shape[0] == M);
            assert(lora_act_out.shape[1] == rank_down);

            lora_act_out.zero_();

            // dispatchVal(rank_down, std::integer_sequence<int, 16, 32, 48, 64, 80>(), [&]<int RANK_DOWN>() {
            using LoraDown = LoraUp; // GEMM::Lora<RANK_DOWN>;

            using Epilogue = typename GEMM::EpilogueCombination<typename LoraUp::EpilogueLoraUp, MidEpilogue, typename LoraDown::EpilogueLoraDown, NextEpilogue, typename GEMM::EpilogueNop>;
            return launch_bias.template operator()<Epilogue>({
                typename LoraUp::EpilogueLoraUp::Arguments{
                    .lora_act = lora_act_in.data_ptr<float>(),
                    .lora_wgt_up = lora_up.data_ptr<packed_fpsum_t>(),
                    .scales = scales,
                },
                midArgs,
                typename LoraDown::EpilogueLoraDown::Arguments{
                    .lora_wgt_down = lora_down.data_ptr<packed_fpsum_t>(),
                    .lora_act = lora_act_out.data_ptr<float>(),
                },
                nextArgs,
                {}
            });
            // });
        });
    };
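
    // Note: a rough sketch (not actual kernel code; packed layouts omitted) of what the fused
    // LoRA epilogues compute, assuming row-major logical shapes and per-16-rank scales:
    //
    //   // EpilogueLoraUp: add the low-rank update to this GEMM's output
    //   //   out[m][n] += sum_r lora_act_in[m][r] * lora_up[n][r] * lora_scales[r / 16];
    //   // EpilogueLoraDown: project the output through lora_down for the next layer
    //   //   lora_act_out[m][r] += sum_n out[m][n] * lora_down[n][r];
    //
    // lora_act_out is zeroed above because the kernel accumulates into it across blocks.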

    if (qout.valid() && oscales.valid()) {
        // dispatchBool(qout_unsigned, [&]<bool USE_UNSIGNED>() {
        static constexpr float SHIFT_GELU = 0.171875f;
        constexpr bool USE_UNSIGNED = true;

        using EpilogueQuantize = typename GEMM::EpilogueQuantize<false, USE_UNSIGNED>;
        auto argsQuantize = typename EpilogueQuantize::Arguments{
            .qout = qout.data_ptr<packed_act_t>(),
            .oscales = oscales.data_ptr<packed_ascale_t>(),
            .shift_value = SHIFT_GELU,
            .smooth_factor = smooth_factor.data_ptr<packed_wscale_t>()
        };

        // TODO: check if gelu is needed
        if (out.valid()) {
            launch_lora.template operator()<typename GEMM::EpilogueCombination<typename GEMM::EpilogueDefault, EpilogueQuantize>, typename GEMM::EpilogueGelu>({
                typename GEMM::EpilogueDefault::Arguments{
                    .out = out.data_ptr<half_t>(),
                    .actualM = actualM,
                    .actualN = actualN,
                },
                argsQuantize
            }, {});
        } else {
            launch_lora.template operator()<EpilogueQuantize, typename GEMM::EpilogueGelu>(argsQuantize, {});
        }
    }
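    // Note: SHIFT_GELU = 0.171875 is presumably chosen because min_x gelu(x) is roughly -0.17
    // (near x = -0.75), so shifting the GELU output by that amount keeps the input to the
    // unsigned quantizer non-negative (USE_UNSIGNED = true above). A host-side sanity check of
    // that minimum, as a sketch:
    //
    //   float m = 0.0f;
    //   for (float x = -6.0f; x <= 6.0f; x += 1e-4f) {
    //       float g = 0.5f * x * (1.0f + erff(x / sqrtf(2.0f)));  // exact GELU
    //       m = fminf(m, g);
    //   }
    //   // m comes out around -0.17, just above -0.171875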
    else if (out_linearattn.valid()) {
        assert(out_vk.valid());

        using Epilogue = typename GEMM::EpilogueLiteLA;

        assert(out_vk.dtype() == Tensor::FP32);
        assert(out_vk.ndims() == 4);
        assert(out_vk.shape[2] == Epilogue::LITELA_HEAD_DIM + 1);
        assert(out_vk.shape[3] == Epilogue::LITELA_HEAD_DIM);
        assert(out_vk.shape[1] * Epilogue::LITELA_HEAD_DIM * 3 == N);

        int batch_size = out_vk.shape[0];
        int num_heads = out_vk.shape[1];

        assert(isTypeMatch<half_t>(out_linearattn.dtype()));
        assert(out_linearattn.ndims() == 3);
        assert(out_linearattn.shape[0] == batch_size);
        assert(out_linearattn.shape[2] * 3 == N);

        int num_tokens = out_linearattn.shape[1];
        assert(num_tokens % GEMM::BLOCK_M == 0);

        int num_blocks_per_batch = ceilDiv(num_tokens, GEMM::BLOCK_M);

        shmem = std::max(shmem, Epilogue::SHMEM_SIZE);

        out_vk.zero_();

        launch_lora.template operator()<Epilogue, typename GEMM::EpilogueNop>(typename Epilogue::Arguments{
            .out_q = out_linearattn.data_ptr<half_t>(),
            .out_vk = out_vk.data_ptr<float>(),
            .num_blocks_per_batch = num_blocks_per_batch,
            .actualM = M,
        }, {});
    }
    else if (rotary_emb.valid()) {
        assert(norm_q.valid());
        assert(norm_k.valid());
        // assert(isTypeMatch<half_t>(rotary_emb.scalar_type()));
        assert(rotary_emb.scalar_type() == Tensor::FP32);
        assert(rotary_emb.numel() == M * GEMM::EpilogueQKVProj::HEAD_DIM / 2 * GEMM::EpilogueQKVProj::ROTARY_EMB_NUM_ELEMENTS);

        launch_lora.template operator()<typename GEMM::EpilogueQKVProj, typename GEMM::EpilogueNop>(typename GEMM::EpilogueQKVProj::Arguments{
            .out = out.data_ptr<half_t>(),
            .actualM = actualM,
            .actualN = actualN,
            .pool_out = poolout.valid() ? poolout.data_ptr<half_t>() : nullptr,
            .rotary_emb = rotary_emb.data_ptr<float>(),
            .rmsnorm_weight_q = norm_q.data_ptr<half_t>(),
            .rmsnorm_weight_k = norm_k.data_ptr<half_t>(),
            .epsilon = 1e-6,
        }, {});
    }
    else if (out.valid()) {
        using Epilogue = typename GEMM::EpilogueDefault;
        typename Epilogue::Arguments args{
            .out = out.data_ptr<half_t>(),
            .actualM = actualM,
            .actualN = actualN,
        };
        if (fuse_silu) {
            launch_lora.template operator()<Epilogue, typename GEMM::EpilogueSilu>(args, {});
        } else {
            launch_lora.template operator()<Epilogue, typename GEMM::EpilogueNop>(args, {});
        }
    } else {
        assert(false);
    }
}
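
// Note: epilogue selection in gemm_w4a4 above is driven by which optional tensors are valid:
//   - qout + oscales valid -> GELU (with SHIFT_GELU) plus requantization of the output, so the
//     next W4A4 layer can consume it directly (optionally also writing the fp16 `out`);
//   - out_linearattn valid -> EpilogueLiteLA, the SANA linear-attention path, accumulating out_vk;
//   - rotary_emb valid     -> EpilogueQKVProj, applying RMSNorm (norm_q / norm_k) and RoPE;
//   - otherwise, out valid -> EpilogueDefault, with SiLU fused when fuse_silu is set.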

template<typename Config>
void GEMM_W4A4_Launch<Config>::linearattn_vk_mul_q(Tensor q, Tensor vk) {
    using Epilogue = typename GEMM::EpilogueLiteLA;

    int batch_size = vk.shape[0];
    int num_heads = vk.shape[1];
    int num_tokens = q.shape[1];

    assert(isTypeMatch<half_t>(q.scalar_type()));
    assert(vk.scalar_type() == Tensor::FP32);

    int BLOCK_SIZE;
    if (num_tokens % 256 == 0) {
        BLOCK_SIZE = 256;
    } else {
        BLOCK_SIZE = 128;
    }

    invoke_kernel<typename Epilogue::vk_mul_q_kernel><<<dim3(ceilDiv(num_tokens, BLOCK_SIZE), num_heads, batch_size), BLOCK_SIZE>>>(
        q.data_ptr<half_t>(),
        vk.data_ptr<float>(),
        1e-6f,
        num_tokens
    );
    checkCUDA(cudaGetLastError());
}
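
// Note: vk_mul_q_kernel itself is defined elsewhere, but from the shapes (vk is
// [B, num_heads, head_dim + 1, head_dim] in FP32, epsilon 1e-6f) it presumably finishes the
// lite linear-attention output roughly as follows per batch, head, and token (a sketch; the
// row layout of vk is an assumption):
//
//   // vk[0 .. head_dim-1][d] : accumulated V * K^T,  vk[head_dim][e] : accumulated sum of K
//   // q_out[d] = (sum_e q[e] * vk[e][d]) / (sum_e q[e] * vk[head_dim][e] + 1e-6f)
//
// which would explain why out_vk carries head_dim + 1 rows per head: the extra row holds the
// normalizer.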

template<typename Config>
void GEMM_W4A4_Launch<Config>::quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu) {
    const int actualM = input.numel() / input.shape[-1];
    const int actualN = input.shape[-1];

    const int M = ceilDiv(actualM, GEMM::BLOCK_M) * GEMM::BLOCK_M;
    const int N = ceilDiv(actualN / (fuse_glu ? 2 : 1), GEMM::BLOCK_N) * GEMM::BLOCK_N;

    assert(output.dtype() == Tensor::INT8);
    assert(output.numel() / output.shape[-1] == M);
    assert(output.shape[-1] == N / 2);
    // assert(oscales.dtype() == Tensor::FP16);
    assert(isTypeMatch<half_t>(oscales.dtype()));
    assert(oscales.numel() == M * N / GEMM::WARP_K);

    const int rank = lora_down.shape[1];

    assert(lora_down.shape[0] == N);
    // assert(lora_down.shape[1] == Lora::LORA_RANK);
    assert(lora_act_out.shape[0] == M);
    assert(lora_act_out.shape[1] == rank);

    lora_act_out.zero_();

    dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);

    dispatchVal(rank, LoraRanks(), [&]<int RANK>() {
        dispatchBool(fuse_glu, [&]<bool FUSE_GLU>() {
            using Lora = typename GEMM::Lora<RANK>;
            using kernel = typename Lora::quantize_w4a4_fuse_lora_kernel<FUSE_GLU>;

            auto func = invoke_kernel<kernel, typename kernel::Arguments>;

            checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));

            // log(std::format("quantize_w4a4_act_fuse_lora M={} N={} input={} output={} (size={} numel={})", M, N, input.data_ptr(), output.data_ptr(), output.buffer->getSize(), output.numel()));

            func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE>>>(
                typename kernel::Arguments{
                    .input = input.data_ptr<half_t>(),
                    .smooth_factor = smooth.valid() ? smooth.data_ptr<packed_wscale_t>() : nullptr,
                    .output = output.data_ptr<packed_act_t>(),
                    .oscales = oscales.data_ptr<packed_ascale_t>(),
                    .lora_wgt_down = lora_down.data_ptr<packed_fpsum_t>(),
                    .lora_act = lora_act_out.data_ptr<float>(),
                    .M = M,
                    .N = N,
                    .actualM = actualM,
                    .actualN = actualN,
                }
            );
            checkCUDA(cudaGetLastError());
        });
    });
}
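
// Note: M and N above are padded to full blocks before launching. A minimal sketch of the
// rounding, assuming ceilDiv(a, b) = (a + b - 1) / b and BLOCK_M = 256 for illustration:
//
//   int M = ceilDiv(actualM, GEMM::BLOCK_M) * GEMM::BLOCK_M;                    // 250 -> 256
//   int N = ceilDiv(actualN / (fuse_glu ? 2 : 1), GEMM::BLOCK_N) * GEMM::BLOCK_N;
//
// With fuse_glu, the two halves of each GLU pair collapse into one output value, so the
// quantized activation has half as many columns as the input.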

template<typename Config>
void GEMM_W4A4_Launch<Config>::quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales) {
    int M = input.numel() / input.shape[-1];
    int K = input.shape[-1];

    assert(output.dtype() == Tensor::INT8);
    assert(output.numel() / output.shape[-1] == M);
    assert(output.shape[-1] == K / 2);
    // assert(oscales.dtype() == Tensor::FP16);
    assert(isTypeMatch<half_t>(oscales.dtype()));
    assert(oscales.numel() == M * K / GEMM::WARP_K);

    dim3 grid(M / GEMM::WARP_M, K / GEMM::WARP_K);
    invoke_kernel<typename GEMM::quantize_w4a4_act_kernel><<<grid, GEMM::WARP_SIZE>>>(
        input.data_ptr<half_t>(),
        output.data_ptr<packed_act_t>(),
        oscales.data_ptr<packed_ascale_t>(),
        K
    );
    checkCUDA(cudaGetLastError());
}

template<typename Config>
void GEMM_W4A4_Launch<Config>::quantize_w4a4_wgt(Tensor input, Tensor output, Tensor oscales) {
    int N = input.numel() / input.shape[-1];
    int K = input.shape[-1];

    assert(output.dtype() == Tensor::INT8);
    assert(output.ndims() == 2);
    assert(output.shape[0] == N);
    assert(output.shape[1] == K / 2);
    assert(isTypeMatch<half_t>(oscales.dtype()));
    // assert(oscales.dtype() == Tensor::FP16);
    assert(oscales.numel() == N * K / GEMM::WARP_K);

    dim3 grid(N / GEMM::WARP_N, K / GEMM::WARP_K);
    invoke_kernel<typename GEMM::quantize_w4a4_wgt_kernel><<<grid, GEMM::WARP_SIZE>>>(
        input.data_ptr<half_t>(),
        output.data_ptr<packed_wgt_t>(),
        oscales.data_ptr<packed_wscale_t>(),
        K
    );
    checkCUDA(cudaGetLastError());
}

}; // namespace nunchaku::kernels
src/kernels/zgemm/gemm_w8a8.cu (new file, mode 100644)
#include "zgemm.h"
#include "gemm_w8a8.cuh"
namespace
nunchaku
::
kernels
{

void quantize_w8a8_act(Tensor input, Tensor output, Tensor oscales, bool fuse_glu) {
    using GEMM = GEMM_W8A8;

    int M = input.numel() / input.shape[-1];
    int K = input.shape[-1];

    assert(output.dtype() == Tensor::INT8);
    assert(output.numel() / output.shape[-1] == M);
    assert(output.shape[-1] == (fuse_glu ? K / 2 : K));
    assert(isTypeMatch<GEMM::half_t>(oscales.dtype()));
    assert(oscales.numel() == M * 1);

    auto launch = [&]<bool FUSE_GLU>() {
        using kernel = GEMM::quantize_w8a8_act_kernel<FUSE_GLU>;

        assert(kernel::check(M, K));
        dim3 grid = kernel::gridSize(M, K);
        dim3 block = kernel::blockSize(M, K);

        auto func = invoke_kernel<kernel,
            const GEMM::half_t *,
            GEMM::packed_act_t *,
            GEMM::packed_ascale_t *,
            int,
            bool>;

        checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, 92160));

        func<<<grid, block, kernel::smemSize(M, K)>>>(
            input.data_ptr<GEMM::half_t>(),
            output.data_ptr<GEMM::packed_act_t>(),
            oscales.data_ptr<GEMM::packed_ascale_t>(),
            K,
            false
        );
        checkCUDA(cudaGetLastError());
    };

    if (fuse_glu) {
        launch.template operator()<true>();
    } else {
        launch.template operator()<false>();
    }
}
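
// Note: the quantizer above is per-token (per-row) symmetric INT8: each row's scale is its
// absolute maximum divided by 127, and values are multiplied by the reciprocal scale before
// rounding (see findmax_warp / quantize_w8a8_warp in gemm_w8a8.cuh). A host-side reference of
// the same math, as a sketch:
//
//   void quantize_row_ref(const float *x, int K, int8_t *q, float &scale) {
//       float maxv = 0.0f;
//       for (int k = 0; k < K; k++) maxv = fmaxf(maxv, fabsf(x[k]));
//       scale = maxv / 127.0f;
//       float rscale = (maxv > 0.0f) ? 1.0f / scale : 0.0f;
//       for (int k = 0; k < K; k++) q[k] = (int8_t)lrintf(x[k] * rscale);
//   }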

void gemm_w8a8(
    Tensor act,     // [M, K]
    Tensor wgt,     // [N, K]
    Tensor out,     // [M, N]
    Tensor ascales, // [1, M]
    Tensor wscales, // [1, N]
    Tensor bias)
{
    using GEMM = GEMM_W8A8;

    int M = act.numel() / act.shape[-1];
    int N = wgt.shape[0];
    int K = act.shape[-1];

    assert(K == wgt.shape[1]);

    int actualM = 0;
    int actualN = 0;
    if (out.valid()) {
        actualM = out.numel() / out.shape[-1];
        actualN = out.shape[-1];

        assert(actualM <= M && M - actualM < GEMM::BLOCK_M);
        assert(actualN <= N && N - actualN < GEMM::BLOCK_N);
    }

    auto launch = [&]<typename Epilogue>(Epilogue::Arguments args) {
        dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);

        bool swapBlockMN = M > N * 2;
        if (swapBlockMN) {
            std::swap(grid.x, grid.y);
        }

        invoke_kernel<GEMM::gemm_w8a8_kernel<Epilogue>><<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS>>>(
            act.data_ptr<GEMM::packed_act_t>(),
            wgt.data_ptr<GEMM::packed_wgt_t>(),
            ascales.data_ptr<GEMM::packed_ascale_t>(),
            wscales.data_ptr<GEMM::packed_wscale_t>(),
            // out.valid() ? out.data_ptr<GEMM::half_t>() : nullptr,
            M, N, K, args,
            swapBlockMN,
            false
        );
        checkCUDA(cudaGetLastError());
    };

    auto launch_bias = [&]<typename NextEpilogue>(NextEpilogue::Arguments nextArgs) {
        if (!bias.valid()) {
            return launch.template operator()<NextEpilogue>(nextArgs);
        }

        assert(bias.numel() == N);

        // append EpilogueNop to work around mismatched memory layout of std::tuple between device and host code on Windows
        // ** sizeof(std::tuple<std::tuple<int>>) == 8 on device **
        using Epilogue = GEMM::EpilogueCombination<GEMM::EpilogueBias, NextEpilogue, GEMM::EpilogueNop>;
        return launch.template operator()<Epilogue>({
            GEMM::EpilogueBias::Arguments{
                .bias = bias.data_ptr<GEMM::packed_wscale_t>(),
            },
            nextArgs,
            {}
        });
    };

    launch_bias.template operator()<GEMM::EpilogueDefault>(GEMM::EpilogueDefault::Arguments{
        .out = out.data_ptr<GEMM::half_t>(),
        .actualM = actualM,
        .actualN = actualN,
    });
}
#if 0
void gemm_w8a8_fuse_litela(
Tensor act, // [B, (M), K]
Tensor wgt, // [N, K]
Tensor out_q, // [B, (M), N / 3]
Tensor out_vk, // [B, num_heads, head_dim + 1, head_dim]
Tensor ascales, // [1, M]
Tensor wscales // [1, N]
) {
using GEMM = GEMM_W8A8;
using Epilogue = GEMM::EpilogueLiteLA;
int M = act.numel() / act.shape[-1];
int N = wgt.shape[0];
int K = act.shape[-1];
assert(K == wgt.shape[1]);
assert(out_vk.ndims() == 4);
assert(out_vk.shape[2] == Epilogue::LITELA_HEAD_DIM + 1);
assert(out_vk.shape[3] == Epilogue::LITELA_HEAD_DIM);
assert(out_vk.shape[1] * Epilogue::LITELA_HEAD_DIM * 3 == N);
int batch_size = out_vk.shape[0];
int num_heads = out_vk.shape[1];
assert(M % batch_size == 0);
int batch_m = M / batch_size;
Epilogue::Arguments epilogueArgs;
epilogueArgs.batch_m = act.shape[1];
epilogueArgs.out_q = out_q.data_ptr<GEMM::half_t>();
epilogueArgs.out_vk = out_vk.data_ptr<float>();
checkCUDA(cudaMemsetAsync(out_vk.data_ptr(), 0, out_vk.buffer->getSize()));
auto func = invoke_kernel<GEMM::gemm_w8a8_kernel<Epilogue>,
const GEMM::packed_act_t *,
const GEMM::packed_wgt_t *,
const GEMM::packed_ascale_t *,
const GEMM::packed_wscale_t *,
// GEMM::half_t *,
int, int, int,
Epilogue::Arguments,
bool,
bool>;
checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, Epilogue::SHMEM_SIZE));
dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);
bool swapBlockMN = M > N * 2;
if (swapBlockMN) {
std::swap(grid.x, grid.y);
}
func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, Epilogue::SHMEM_SIZE>>>(
act.data_ptr<GEMM::packed_act_t>(),
wgt.data_ptr<GEMM::packed_wgt_t>(),
ascales.data_ptr<GEMM::packed_ascale_t>(),
wscales.data_ptr<GEMM::packed_wscale_t>(),
// nullptr,
M, N, K, epilogueArgs,
swapBlockMN,
false
);
checkCUDA(cudaGetLastError());
invoke_kernel<Epilogue::vk_mul_q_kernel><<<dim3(batch_m / 128, num_heads, batch_size), 128>>>(
out_q.data_ptr<GEMM::half_t>(),
out_vk.data_ptr<float>(),
1e-6f
);
checkCUDA(cudaGetLastError());
}
#endif
}; // namespace nunchaku::kernels
src/kernels/zgemm/gemm_w8a8.cuh (new file, mode 100644)
#pragma once
#include "gemm_base.cuh"

namespace nunchaku::kernels {

class GEMM_W8A8 : public GEMMBase<GEMMConfig_W8A8> {
public:
    using psum_warp = std::array<packed_psum_t, WARP_M_TILES * WARP_N_TILES>;

    __device__ __forceinline__
    static packed_psum_t mma(packed_act_t act, packed_wgt_t wgt, packed_psum_t psum) {
        // packed_psum_t psum;

        asm volatile(
            "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 "
            "{%0, %1, %2, %3},"
            "{%4, %5, %6, %7},"
            "{%8, %9},"
            "{%10, %11, %12, %13};\n"
            : "=r"(psum.data[0]), "=r"(psum.data[1]), "=r"(psum.data[2]), "=r"(psum.data[3])
            : "r"(act.x), "r"(act.y), "r"(act.z), "r"(act.w),
              "r"(wgt.x), "r"(wgt.y),
              // "r"(0), "r"(0), "r"(0), "r"(0)
              "r"(psum.data[0]), "r"(psum.data[1]), "r"(psum.data[2]), "r"(psum.data[3])
        );
        asm volatile(
            "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 "
            "{%0, %1, %2, %3},"
            "{%4, %5, %6, %7},"
            "{%8, %9},"
            "{%10, %11, %12, %13};\n"
            : "=r"(psum.data[4]), "=r"(psum.data[5]), "=r"(psum.data[6]), "=r"(psum.data[7])
            : "r"(act.x), "r"(act.y), "r"(act.z), "r"(act.w),
              "r"(wgt.z), "r"(wgt.w),
              // "r"(0), "r"(0), "r"(0), "r"(0)
              "r"(psum.data[4]), "r"(psum.data[5]), "r"(psum.data[6]), "r"(psum.data[7])
        );
        return psum;
    }
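
    // Note: each mma.sync.aligned.m16n8k32 ... .s8.s8.s32 instruction multiplies a 16x32 int8
    // A-fragment by a 32x8 int8 B-fragment into a 16x8 int32 accumulator; the two issues above
    // reuse the same activation fragment with (wgt.x, wgt.y) and (wgt.z, wgt.w) as adjacent
    // 8-column weight fragments, so one mma() call covers a 16x16 output tile. A scalar
    // reference of the accumulation (logical view only, ignoring the per-lane fragment layout):
    //
    //   // C[16][16] += A[16][32] * B[16][32]^T, products accumulated in int32
    //   for (int i = 0; i < 16; i++)
    //       for (int j = 0; j < 16; j++)
    //           for (int k = 0; k < 32; k++)
    //               C[i][j] += (int32_t)A[i][k] * (int32_t)B[j][k];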

    __device__ __forceinline__
    static void compute(act_warp A, wgt_warp W, psum_warp &psum) {
        const int laneId = threadIdx.x % WARP_SIZE;
        const int warpId = threadIdx.x / WARP_SIZE;

#pragma unroll
        for (int j = 0; j < WARP_N_TILES; j++) {
#pragma unroll
            for (int i = 0; i < WARP_M_TILES; i++) {
                psum[i * WARP_N_TILES + j] = mma(A[i], W[j], psum[i * WARP_N_TILES + j]);
            }
        }
    }
/**
* each warp quantizes a INSN_M * INSN_K (16 * 32) matrix
* input is per-warp (in global memory / shared memory)
* oscales is per-warp (in shared memory)
* output is per-thread (in regs)
* shmem must be at least INSN_M * (INSN_K * sizeof(element) + 16) (16 * 32 = 512 Bytes)
* default to quantize activation, if quantize weight, input should be column-majored and output should be transposed ({x, y, z, w} = {x, z, y, w})
*/
    template<bool input_shmem = false>
    __device__ __forceinline__
    static void quantize_w8a8_warp(const half_t *input, const half_t *oscales, int stride, packed_act_t &output, void *shmem) {
        const int laneId = threadIdx.x % WARP_SIZE;

        constexpr int QUANTIZE_BITWIDTH = 8;
        // constexpr int QUANTIZE_BITMASK = 0xff;
        // constexpr int QVALUE_MAX = 128;  // 4 bit => [-128, 127]

        // 1 lane = 1 pack
        // 1 warp = 32 lanes = 32 packs = 1 packwarp
        // a pack is {a0, ..., a7} in figure https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=ex2#mma-16864-a
        // PACK_SIZE * 4 = INSN_K / 2
        constexpr int PACK_SIZE = INSN_K / 8;  // = 4 for 8bit
        constexpr int NUM_PACKS_PER_ROW = INSN_K / PACK_SIZE;
        constexpr int NUM_ROWS_PER_PACKWARP = PACK_SIZE * WARP_SIZE / INSN_K;
        constexpr int NUM_PACKWARPS = INSN_M / NUM_ROWS_PER_PACKWARP;
        using packed_input = std::array<half_t, PACK_SIZE>;

        packed_input packs[NUM_PACKWARPS];

        // load
#pragma unroll
        for (int i = 0; i < NUM_PACKWARPS; i++) {
            int rowId = i * NUM_ROWS_PER_PACKWARP + laneId / NUM_PACKS_PER_ROW;
            int colId = laneId % NUM_PACKS_PER_ROW * PACK_SIZE;
            packs[i] = load<input_shmem>(reinterpret_cast<const packed_input *>(input + rowId * stride + colId));
        }

        // quantize
        using matrix_t = uint32_t[INSN_M][NUM_PACKS_PER_ROW];
        matrix_t &mat = *reinterpret_cast<matrix_t *>(shmem);
#pragma unroll
        for (int i = 0; i < NUM_PACKWARPS; i++) {
            const int row = i * NUM_ROWS_PER_PACKWARP + laneId / NUM_PACKS_PER_ROW;
            const int col = laneId % NUM_PACKS_PER_ROW;
            float rscale = cuda_frcp(float(oscales[row]));

            uint32_t qpack = 0;
#pragma unroll
            for (int j = 0; j < PACK_SIZE; j += 2) {
                // half2_t hval = __hmul2(half2_t(rscale, rscale), half2_t(packs[i][j], packs[i][j + 1]));
                float2 fval = half22float2(half2_t(packs[i][j], packs[i][j + 1])) * float2(rscale, rscale);
                qpack |= quantize_float2<QUANTIZE_BITWIDTH, false>(fval) << (j * QUANTIZE_BITWIDTH);
            }
            mat[row][col] = qpack;
        }
        __syncwarp();

        // convert to imma format
        int row = laneId % 16;
        int col = laneId / 16 * 4;
        ldmatrix(&mat[row][col], output);

        __syncwarp();
    }
/**
* each warp finds absmax from a row
*/
    template<bool fuse_glu = false>
    __device__ __forceinline__
    static half_t findmax_warp(const half_t *input, half_t *output_shmem, int K, bool alwaysfalse) {
        const int laneId = threadIdx.x % WARP_SIZE;

        using packed_input = std::array<half2_t, 4>;
        using packed_gated_input = std::array<half_t, 4>;
        constexpr int PACK_SIZE = sizeof(packed_input) / sizeof(half_t);

        constexpr int NUM_STAGES = 2;

        half2_t maxvalue2 = {0, 0};

        packed_input pack[NUM_STAGES];
#pragma unroll
        for (int k = 0; k < NUM_STAGES - 1; k++) {
            const int idx = k * PACK_SIZE * WARP_SIZE + laneId * PACK_SIZE;
            if (idx < K) {
                pack[k] = load(reinterpret_cast<const packed_input *>(&input[idx]));
            } else {
                pack[k].fill(half2_t(0, 0));
            }
        }

        // int dummy = 0;

        // FIXME: pipeline does not work
        // TODO: store quantized data to shmem (instead of half)
        for (int k1 = 0; k1 < ceilDiv(K, PACK_SIZE * WARP_SIZE); k1 += NUM_STAGES) {
#pragma unroll
            for (int k2 = 0; k2 < NUM_STAGES; k2++) {
                const int nextidx = (k1 + k2 + NUM_STAGES - 1) * PACK_SIZE * WARP_SIZE + laneId * PACK_SIZE;
                const int nextk2 = (k2 + NUM_STAGES - 1) % NUM_STAGES;

                if (nextidx < K) {
                    pack[nextk2] = load(reinterpret_cast<const packed_input *>(&input[nextidx]));
                } else {
                    pack[nextk2].fill(half2_t(0, 0));
                }

                packed_input &p = pack[k2];

                if constexpr (fuse_glu) {
                    packed_gated_input gated;
#pragma unroll
                    for (int j = 0; j < p.size(); j++) {
                        gated[j] = p[j].x * gelu_half(p[j].y);
                        p[j].x = gated[j];
                        p[j].y = 0;
                    }

                    int idx = (k1 + k2) * PACK_SIZE / 2 * WARP_SIZE + laneId * PACK_SIZE / 2;
                    if (idx < K) {
                        store<true>(reinterpret_cast<packed_gated_input *>(&output_shmem[idx]), gated);
                    }
                }

#pragma unroll
                for (int j = 0; j < p.size(); j++) {
                    maxvalue2 = __hmax2(maxvalue2, __habs2(p[j]));
                }
            }
        }

        // unused_var(dummy, alwaysfalse);

#pragma unroll
        for (int mask = 32 / 2; mask > 0; mask /= 2) {
            maxvalue2 = __hmax2(maxvalue2, __shfl_xor_sync(~0, maxvalue2, mask));
        }

        return __hmax(maxvalue2.x, maxvalue2.y);
    }
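
    // Note: the __shfl_xor_sync loop above is the standard butterfly reduction: after five steps
    // with masks 16, 8, 4, 2, 1 every lane of the warp holds the same maximum, so any lane may
    // write the per-row scale. The same pattern for a plain float, as a sketch:
    //
    //   __device__ float warp_max(float v) {
    //       for (int mask = 16; mask > 0; mask /= 2)
    //           v = fmaxf(v, __shfl_xor_sync(0xffffffff, v, mask));
    //       return v;  // identical in all 32 lanes
    //   }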

    // each thread block quantizes a WARP_M * K tile (32 * K)
    template<bool fuse_glu>
    struct quantize_w8a8_act_kernel {
        static constexpr bool check(int M, int K) {
            const int K2 = fuse_glu ? K / 2 : K;
            return M % WARP_M == 0 && K2 % WARP_K == 0;
        }

        static constexpr dim3 gridSize(int M, int K) {
            return dim3(M / WARP_M);
        }
        static constexpr dim3 blockSize(int M, int K) {
            return dim3(NUM_WARPS * 32);
        }
        static constexpr size_t smemSize(int M, int K) {
            if constexpr (!fuse_glu) {
                return 0;
            }
            const int K2 = fuse_glu ? K / 2 : K;
            return INSN_M * K2 * sizeof(half_t);
        }

        __device__
        void operator()(const half_t *input, packed_act_t *output, packed_ascale_t *oscales, int K, bool alwaysfalse) {
            // for quantize kernel
            const int laneId = threadIdx.x % WARP_SIZE;
            const int warpId = threadIdx.x / WARP_SIZE;
            const int numWarps = blockDim.x / WARP_SIZE;

            // for GEMM kernel
            const int bm = blockIdx.x / (BLOCK_M / WARP_M);
            const int gemmWarpId = blockIdx.x % (BLOCK_M / WARP_M);

            __shared__ alignas(128) half_t oscale_shmem[WARP_M];
            // __shared__ alignas(128) half_t maxv_shmem[WARP_M];
            __shared__ alignas(128) uint8_t tmp_shmem[NUM_WARPS][512];

            const int K2 = fuse_glu ? K / 2 : K;

            // INSN_M * K2
            extern __shared__ uint8_t smem[];
            half_t *shmem = reinterpret_cast<half_t *>(smem);

            for (int tileM = 0; tileM < WARP_M_TILES; tileM++) {
                for (int i = warpId; i < INSN_M; i += numWarps) {
                    const int rowLocal = tileM * INSN_M + i;
                    const int rowGlobal = blockIdx.x * WARP_M + rowLocal;
                    half_t maxv = findmax_warp<fuse_glu>(input + rowGlobal * K, shmem + i * K2, K, alwaysfalse);
                    oscale_shmem[rowLocal] = maxv / half_t(127);
                    // rscale_shmem[rowLocal] = half_t(127) / maxv;
                    // maxv_shmem[rowLocal] = maxv;
                }

                __syncthreads();

                for (int bk = warpId; bk < K2 / WARP_K; bk += numWarps) {
                    const int rowLocal = tileM * INSN_M;
                    const int rowGlobal = blockIdx.x * WARP_M + rowLocal;
                    const int col = bk * WARP_K;

                    packed_act_t tmpout;

                    if constexpr (fuse_glu) {
                        quantize_w8a8_warp<true>(
                            shmem + col,
                            oscale_shmem + rowLocal,
                            K2,
                            tmpout,
                            &tmp_shmem[warpId]
                        );
                    } else {
                        quantize_w8a8_warp<false>(
                            input + rowGlobal * K + col,
                            oscale_shmem + rowLocal,
                            K,
                            tmpout,
                            &tmp_shmem[warpId]
                        );
                    }

                    store(&output[(((bm * K2 / WARP_K + bk) * NUM_WARPS + gemmWarpId) * WARP_M_TILES + tileM) * WARP_SIZE + laneId], tmpout);
                }

                __syncthreads();
            }

            // [M / BLOCK_M, 1, NUM_WARPS, ASCALES_NUM_PACKS, ASCALES_VALID_LANES] of packed_ascale_t
            pack_ascales(oscale_shmem, &oscales[(bm * NUM_WARPS + gemmWarpId) * ASCALES_NUM_PACKS * ASCALES_VALID_LANES]);
        }
    };

    __device__ __forceinline__
    static gated_fpsum_warp apply_glu(fpsum_warp fpsum) {
        gated_fpsum_warp result;
        for (int i = 0; i < WARP_M_TILES; i++) {
            for (int j = 0; j < WARP_N_TILES; j++) {
                for (int k = 0; k < 4; k++) {
                    half_t &dst = result[i * WARP_N_TILES + j].data[k];
                    half2_t src = fpsum[i * WARP_N_TILES + j].data[k];
                    dst = src.x * gelu_half(src.y);
                }
            }
        }
        return result;
    }

    static constexpr int unpack_gated_fpsum_shmem_size = INSN_M * (WARP_N / 2 + 8) * sizeof(half_t);

    __device__ __forceinline__
    static void unpack_gated_fpsum(gated_fpsum_warp fpsum, half_t *output, int stride, void *shmem) {
        const int laneId = threadIdx.x % WARP_SIZE;

        constexpr int PACK_SIZE = WARP_N / 2 / WARP_SIZE;
        using pack_t = std::array<half_t, PACK_SIZE>;

        // +8 to prevent bank conflicts
        using matrix_t = half_t[INSN_M][WARP_N / 2 + 8];
        matrix_t &mat = *reinterpret_cast<matrix_t *>(shmem);

        for (int i = 0; i < WARP_M_TILES; i++) {
            for (int j = 0; j < WARP_N_TILES; j++) {
                packed_gated_fpsum_t &fsum = fpsum[i * WARP_N_TILES + j];
                int row = laneId / 4;
                int col = laneId % 4 + j * INSN_N / 2;
                *reinterpret_cast<half_t *>(&mat[row][col + 0])     = fsum.data[0];
                *reinterpret_cast<half_t *>(&mat[row][col + 4])     = fsum.data[2];
                *reinterpret_cast<half_t *>(&mat[row + 8][col + 0]) = fsum.data[1];
                *reinterpret_cast<half_t *>(&mat[row + 8][col + 4]) = fsum.data[3];
            }
            __syncwarp();
            for (int row = 0; row < INSN_M; row++) {
                pack_t pack = *reinterpret_cast<pack_t *>(&mat[row][laneId * PACK_SIZE]);
                store(reinterpret_cast<pack_t *>(&output[(i * INSN_M + row) * stride + laneId * PACK_SIZE]), pack);
            }
            __syncwarp();
        }
    }

    // out: [M, N] <=> [..., NUM_WARPS, WARP_M, N] of half
    template<typename Epilogue>
    __device__ __forceinline__
    static void gemm_w8a8_block(
        const BlockInfo binfo,
        const packed_act_t *act,
        const packed_wgt_t *wgt,
        const packed_ascale_t *ascales,
        const packed_wscale_t *wscales,
        // half_t *out,
        int M, int N, int K,
        Epilogue::Arguments epilogeParams,
        bool alwaysfalse)
    {
        constexpr int NUM_STAGES = 2;

        const int laneId = threadIdx.x % WARP_SIZE;
        const int warpId = threadIdx.x / WARP_SIZE;

        act_warp A[NUM_STAGES];  // 8
        wgt_warp W[NUM_STAGES];  // 32
        ascale_warp ascale;      // 1
        wscale_warp wscale;      // 2
        psum_warp psum;          // 128

        for (auto &pack : psum) {
            for (int i = 0; i < 8; i++) {
                pack.data[i] = 0;
            }
        }

        // load_wscale<true>(wscales, wscale[0], true);
        // load_wscale<false>(wscales, wscale[1], true);
        // load_wscale<false>(wscales, wscale[2], true);

        load_ascale(ascales, 0, M, ascale, true);
        load_wscale(wscales, 0, N, wscale, true);

        for (int k = 0; k < NUM_STAGES - 1; k++) {
            load_act(act, k, K, A[k], true);
            load_wgt(wgt, k, K, W[k], true);
        }

        int dummy = 0;

        for (int k1 = 0; k1 < K / WARP_K; k1 += NUM_STAGES) {
#pragma unroll
            for (int k2 = 0; k2 < NUM_STAGES; k2++) {
                int nextk = k1 + k2 + NUM_STAGES - 1;
                int idx = (k2 + NUM_STAGES - 1) % NUM_STAGES;
                bool pred = nextk < K / WARP_K;
                load_act(act, nextk, K, A[idx], pred);
                load_wgt(wgt, nextk, K, W[idx], pred);
                // load_wscale<false>(wscales, wscale[idx], pred);

                // __syncthreads();
                // if (alwaysfalse) {
                //     dummy = clock();
                // }
                // if (alwaysfalse) {
                //     dummy = clock();
                // }

                compute(A[k2], W[k2], psum);

                // if (alwaysfalse) {
                //     dummy = clock();
                // }

                // asm volatile ("membar.cta;");
            }
        }

        unused_var(dummy, alwaysfalse);

        f32psum_warp f32psum;

#pragma unroll
        for (int i = 0; i < f32psum.size(); i++) {
#pragma unroll
            for (int j = 0; j < 8; j++) {
                f32psum[i].data[j] = 0;
            }
        }

        apply_scales([&](int i, int j) {
            return psum[i * WARP_N_TILES + j];
        }, ascale, wscale, f32psum);

        fpsum_warp fpsum = packed_fp32_to_fp16(f32psum);

        // if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x % 32 == 0) {
        //     printf("warpId = %d fpsum = %f\n", warpId, (float)fpsum[0].data[0].x);
        // }

        Epilogue()(binfo, fpsum, M, N, K, epilogeParams);
    }
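
    // Note: gemm_w8a8_block is software-pipelined with NUM_STAGES = 2: the (predicated) loads
    // for K-tile k+1 are issued before the tensor-core work for tile k, so global-memory latency
    // overlaps with compute. The generic shape of such a loop, as a sketch:
    //
    //   load(tile 0);
    //   for (int k = 0; k < K_TILES; k++) {
    //       load(tile k + 1);   // predicated off once k + 1 == K_TILES
    //       compute(tile k);
    //   }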

    // out : [M / BLOCK_M, BLOCK_M, N / BLOCK_N, BLOCK_N]
    template<typename Epilogue>
    struct gemm_w8a8_kernel {
        __device__
        void operator()(
            const packed_act_t *act,
            const packed_wgt_t *wgt,
            const packed_ascale_t *ascales,
            const packed_wscale_t *wscales,
            // half_t *out,
            int M, int N, int K,
            Epilogue::Arguments epilogueArgs,
            bool swapBlockXY,
            bool alwaysfalse)
        {
            BlockInfo binfo = {
                .bm = (int)blockIdx.x,
                .bn = (int)blockIdx.y,
                .numBlocksM = (int)gridDim.x,
                .numBlocksN = (int)gridDim.y,
            };

            if (swapBlockXY) {
                std::swap(binfo.bm, binfo.bn);
                std::swap(binfo.numBlocksM, binfo.numBlocksN);
            }

            const int bm = binfo.bm;
            const int bn = binfo.bn;

            gemm_w8a8_block<Epilogue>(
                binfo,
                act + bm * (K / WARP_K) * NUM_WARPS * WARP_M_TILES * WARP_SIZE,
                wgt + bn * (K / WARP_K) * WARP_N_TILES * WARP_SIZE,
                ascales + bm * (1) * NUM_WARPS * ASCALES_NUM_PACKS * ASCALES_VALID_LANES,  // only 1 group in W8A8
                wscales + bn * (1) * WSCALES_NUM_PACKS * WSCALES_VALID_LANES,
                // #if 1
                //     out + (bm * BLOCK_M * N) + bn * BLOCK_N,
                // #else
                //     out + (bm * BLOCK_M * N / 2) + bn * BLOCK_N / 2,
                // #endif
                M, N, K,
                epilogueArgs,
                alwaysfalse
            );
        }
    };
#if 0
struct EpilogueGLU {
struct Arguments { size_t unused; };
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp fpsum, half_t *out, int M, int N, int K, Arguments args) {
const int warpId = threadIdx.x / WARP_SIZE;
gated_fpsum_warp gated_fpsum = apply_glu(fpsum);
__shared__ alignas(128) uint8_t shmem[NUM_WARPS][ceilDiv(unpack_gated_fpsum_shmem_size, 128) * 128];
unpack_gated_fpsum(gated_fpsum, out + warpId * WARP_M * N / 2, N / 2, shmem[warpId]);
}
};
#endif
};

}; // namespace nunchaku::kernels
src/kernels/gemm_w4a4.h → src/kernels/zgemm/zgemm.h (renamed)
@@ -3,6 +3,8 @@
 #include "common.h"
 #include "Tensor.h"
+
+namespace nunchaku::kernels {
 void gemm_w4a4(
     Tensor act,  // packed act [M, K / 2]
     Tensor wgt,  // packed act [N, K / 2]
@@ -21,11 +23,15 @@ void gemm_w4a4(
     Tensor rotary_emb,      // linear [M, HEAD_DIM / 2, 2, 2]
     Tensor bias,            // packed ws [N]
     Tensor smooth_factor,   // packed ws [N], for quantization of the next layer
+    Tensor out_vk,          // linear [B, num_heads, head_dim + 1, head_dim]
+    Tensor out_linearattn,  // linear [B, (M), N / 3]
     bool act_unsigned,
-    std::vector<float> lora_scales  // [R / 16]
+    std::vector<float> lora_scales, // [R / 16]
+    bool fuse_silu
 );
+void linearattn_vk_mul_q(Tensor q, Tensor vk);
-void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth = {});
+void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth = {}, bool fuse_glu = false);
 void quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales);
 void quantize_w4a4_wgt(Tensor input, Tensor output, Tensor oscales);
@@ -33,16 +39,19 @@ void gemm_w8a8(Tensor act, // [M, K]
     Tensor wgt,     // [N, K]
     Tensor out,     // [M, N]
     Tensor ascales, // [1, M]
-    Tensor wscales  // [1, N]
+    Tensor wscales, // [1, N]
+    Tensor bias     // packed ws [N]
 );
 void quantize_w8a8_act(Tensor input, Tensor output, Tensor oscales, bool fuse_glu);
-void gemm_w8a8_fuse_litela(
-    Tensor act,     // [B, (M), K]
-    Tensor wgt,     // [N, K]
-    Tensor out_q,   // [B, (M), N / 3]
-    Tensor out_vk,  // [B, num_heads, head_dim + 1, head_dim]
-    Tensor ascales, // [1, M]
-    Tensor wscales  // [1, N]
-);
+// void gemm_w8a8_fuse_litela(
+//     Tensor act,     // [B, (M), K]
+//     Tensor wgt,     // [N, K]
+//     Tensor out_q,   // [B, (M), N / 3]
+//     Tensor out_vk,  // [B, num_heads, head_dim + 1, head_dim]
+//     Tensor ascales, // [1, M]
+//     Tensor wscales  // [1, N]
+// );
+
+}; // namespace nunchaku::kernels
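
In summary, the zgemm.h changes move the public kernel entry points into namespace nunchaku::kernels, extend gemm_w4a4 with the SANA-specific outputs out_vk and out_linearattn plus a fuse_silu flag, add linearattn_vk_mul_q, add a fuse_glu option to quantize_w4a4_act_fuse_lora, add a bias argument to gemm_w8a8, and comment out the standalone gemm_w8a8_fuse_litela declaration, whose role is presumably taken over by the fused LiteLA epilogue in the W4A4 path.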