Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
fengzch-das
nunchaku
Commits
37a27712
Unverified
Commit
37a27712
authored
May 01, 2025
by
Muyang Li
Committed by
GitHub
May 01, 2025
Browse files
Merge pull request #340 from mit-han-lab/dev
feat: support PuLID, Double FBCache and TeaCache; better linter
parents
c1d6fc84
760ab022
Changes
192
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
981 additions
and
1027 deletions
+981
-1027
src/kernels/zgemm/gemm_w4a4_launch_bf16_fp4.cu
src/kernels/zgemm/gemm_w4a4_launch_bf16_fp4.cu
+2
-2
src/kernels/zgemm/gemm_w4a4_launch_bf16_int4.cu
src/kernels/zgemm/gemm_w4a4_launch_bf16_int4.cu
+2
-2
src/kernels/zgemm/gemm_w4a4_launch_fp16_fp4.cu
src/kernels/zgemm/gemm_w4a4_launch_fp16_fp4.cu
+2
-2
src/kernels/zgemm/gemm_w4a4_launch_fp16_int4.cu
src/kernels/zgemm/gemm_w4a4_launch_fp16_int4.cu
+2
-2
src/kernels/zgemm/gemm_w4a4_launch_fp16_int4_fasteri2f.cu
src/kernels/zgemm/gemm_w4a4_launch_fp16_int4_fasteri2f.cu
+2
-2
src/kernels/zgemm/gemm_w4a4_launch_impl.cuh
src/kernels/zgemm/gemm_w4a4_launch_impl.cuh
+201
-187
src/kernels/zgemm/gemm_w4a4_test.cu
src/kernels/zgemm/gemm_w4a4_test.cu
+31
-33
src/kernels/zgemm/gemm_w8a8.cu
src/kernels/zgemm/gemm_w8a8.cu
+39
-42
src/kernels/zgemm/gemm_w8a8.cuh
src/kernels/zgemm/gemm_w8a8.cuh
+148
-162
src/kernels/zgemm/lora.cuh
src/kernels/zgemm/lora.cuh
+70
-75
src/kernels/zgemm/mma.cuh
src/kernels/zgemm/mma.cuh
+137
-151
src/kernels/zgemm/mma_earlycuda.cuh
src/kernels/zgemm/mma_earlycuda.cuh
+156
-202
src/kernels/zgemm/zgemm.h
src/kernels/zgemm/zgemm.h
+50
-47
src/layernorm.cpp
src/layernorm.cpp
+21
-12
src/layernorm.h
src/layernorm.h
+18
-10
src/pytorch_compat.h
src/pytorch_compat.h
+94
-91
tests/README.md
tests/README.md
+2
-2
tests/data/__init__.py
tests/data/__init__.py
+0
-1
tests/flux/test_flux_cache.py
tests/flux/test_flux_cache.py
+2
-1
tests/flux/test_flux_dev.py
tests/flux/test_flux_dev.py
+2
-1
No files found.
src/kernels/zgemm/gemm_w4a4_launch_bf16_fp4.cu
View file @
37a27712
#include "gemm_w4a4_launch_impl.cuh"
namespace
nunchaku
::
kernels
{
template
class
GEMM_W4A4_Launch
<
GEMMConfig_W4A4_BF16
,
true
>;
template
class
GEMM_W4A4_Launch
<
GEMMConfig_W4A4_BF16
,
true
>;
};
src/kernels/zgemm/gemm_w4a4_launch_bf16_int4.cu
View file @
37a27712
#include "gemm_w4a4_launch_impl.cuh"
namespace
nunchaku
::
kernels
{
template
class
GEMM_W4A4_Launch
<
GEMMConfig_W4A4_BF16
,
false
>;
template
class
GEMM_W4A4_Launch
<
GEMMConfig_W4A4_BF16
,
false
>;
};
src/kernels/zgemm/gemm_w4a4_launch_fp16_fp4.cu
View file @
37a27712
#include "gemm_w4a4_launch_impl.cuh"
namespace
nunchaku
::
kernels
{
template
class
GEMM_W4A4_Launch
<
GEMMConfig_W4A4_FP16
,
true
>;
template
class
GEMM_W4A4_Launch
<
GEMMConfig_W4A4_FP16
,
true
>;
};
src/kernels/zgemm/gemm_w4a4_launch_fp16_int4.cu
View file @
37a27712
#include "gemm_w4a4_launch_impl.cuh"
namespace
nunchaku
::
kernels
{
template
class
GEMM_W4A4_Launch
<
GEMMConfig_W4A4_FP16
,
false
>;
template
class
GEMM_W4A4_Launch
<
GEMMConfig_W4A4_FP16
,
false
>;
};
src/kernels/zgemm/gemm_w4a4_launch_fp16_int4_fasteri2f.cu
View file @
37a27712
#include "gemm_w4a4_launch_impl.cuh"
namespace
nunchaku
::
kernels
{
template
class
GEMM_W4A4_Launch
<
GEMMConfig_W4A4_FP16_FasterI2F
,
false
>;
template
class
GEMM_W4A4_Launch
<
GEMMConfig_W4A4_FP16_FasterI2F
,
false
>;
};
src/kernels/zgemm/gemm_w4a4_launch_impl.cuh
View file @
37a27712
...
...
@@ -27,7 +27,7 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
Tensor
bias
,
// packed ws [N]
Tensor
smooth_factor
,
// packed ws [N], for quantization of the next layer
Tensor
out_vk
,
// linear [B, num_heads, head_dim + 1, head_dim]
Tensor
out_linearattn
,
// linear [B, (M), N / 3]
Tensor
out_linearattn
,
// linear [B, (M), N / 3]
bool
act_unsigned
,
std
::
vector
<
float
>
lora_scales
,
// [R / 16]
bool
fuse_silu
,
...
...
@@ -37,8 +37,7 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
Tensor
out_q
,
// packed attention [B, H, M, D]
Tensor
out_k
,
// packed attention [B, H, M, D]
Tensor
out_v
,
// packed attention [B, H, M, D]
int
attn_tokens
)
{
int
attn_tokens
)
{
#ifdef __INTELLISENSE__
static
constexpr
bool
USE_FP4
=
false
;
#endif
...
...
@@ -94,7 +93,9 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
const
packed_wgt_t
*
,
const
packed_ascale_t
*
,
const
packed_wscale_t
*
,
int
,
int
,
int
,
int
,
int
,
int
,
typename
Epilogue
::
Arguments
,
bool
,
bool
>
;
...
...
@@ -110,11 +111,12 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
wgt
.
data_ptr
<
packed_wgt_t
>
(),
ascales
.
data_ptr
<
packed_ascale_t
>
(),
wscales
.
data_ptr
<
packed_wscale_t
>
(),
M
,
N
,
K
,
M
,
N
,
K
,
args
,
swapBlockMN
,
false
);
false
);
checkCUDA
(
cudaGetLastError
());
});
return
;
...
...
@@ -130,7 +132,9 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
const
packed_amscale_t
*
,
const
packed_wmscale_t
*
,
float
,
int
,
int
,
int
,
int
,
int
,
int
,
typename
Epilogue
::
Arguments
,
bool
,
bool
>
;
...
...
@@ -148,11 +152,12 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
ascales
.
data_ptr
<
packed_amscale_t
>
(),
wscales
.
data_ptr
<
packed_wmscale_t
>
(),
alpha
,
M
,
N
,
K
,
M
,
N
,
K
,
args
,
swapBlockMN
,
false
);
false
);
checkCUDA
(
cudaGetLastError
());
});
...
...
@@ -171,23 +176,25 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
dispatchBool
(
bias
.
valid
(),
[
&
]
<
bool
USE_BIAS
>
()
{
dispatchBool
(
wcscales
.
valid
(),
[
&
]
<
bool
USE_SCALE
>
()
{
using
EpilogueBias
=
typename
GEMM
::
EpilogueBias
<
USE_BIAS
,
USE_SCALE
>
;
// append EpilgoueNop to workaround mismatched memory layout of std::tuple between device and host code on Windows
// append EpilgoueNop to workaround mismatched memory layout of std::tuple between device and host code
// on Windows
// ** sizeof(std::tuple<std::tuple<int>>) == 8 on device **
using
Epilogue
=
typename
GEMM
::
EpilogueCombination
<
EpilogueBias
,
NextEpilogue
,
typename
GEMM
::
EpilogueNop
>
;
return
launch
.
template
operator
()
<
Epilogue
>({
typename
EpilogueBias
::
Arguments
{
using
Epilogue
=
typename
GEMM
::
EpilogueCombination
<
EpilogueBias
,
NextEpilogue
,
typename
GEMM
::
EpilogueNop
>
;
return
launch
.
template
operator
()
<
Epilogue
>(
{
typename
EpilogueBias
::
Arguments
{
.
bias
=
USE_BIAS
?
bias
.
data_ptr
<
packed_wscale_t
>
()
:
nullptr
,
.
scale
=
USE_SCALE
?
wcscales
.
data_ptr
<
packed_wscale_t
>
()
:
nullptr
,
},
nextArgs
,
{}
});
{}});
});
});
};
// auto launch_bias = launch;
auto
launch_lora
=
[
&
]
<
typename
NextEpilogue
,
typename
MidEpilogue
>
(
NextEpilogue
::
Arguments
nextArgs
,
MidEpilogue
::
Arguments
midArgs
)
{
auto
launch_lora
=
[
&
]
<
typename
NextEpilogue
,
typename
MidEpilogue
>
(
NextEpilogue
::
Arguments
nextArgs
,
MidEpilogue
::
Arguments
midArgs
)
{
assert
(
lora_up
.
valid
()
==
lora_act_in
.
valid
());
assert
(
lora_down
.
valid
()
==
lora_act_out
.
valid
());
...
...
@@ -196,10 +203,10 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
if
(
rank_up
==
0
)
{
assert
(
rank_down
==
0
);
return
launch_bias
.
template
operator
()
<
typename
GEMM
::
EpilogueCombination
<
MidEpilogue
,
NextEpilogue
>
>
({
midArgs
,
nextArgs
});
return
launch_bias
.
template
operator
()
<
typename
GEMM
::
EpilogueCombination
<
MidEpilogue
,
NextEpilogue
>
>
(
{
midArgs
,
nextArgs
});
}
assert
(
rank_up
%
16
==
0
);
assert
(
lora_up
.
shape
[
0
]
==
N
);
...
...
@@ -218,9 +225,11 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
}
if
(
rank_down
==
0
)
{
using
Epilogue
=
typename
GEMM
::
EpilogueCombination
<
typename
LoraUp
::
EpilogueLoraUp
,
MidEpilogue
,
NextEpilogue
,
typename
GEMM
::
EpilogueNop
>
;
return
launch_bias
.
template
operator
()
<
Epilogue
>({
typename
LoraUp
::
EpilogueLoraUp
::
Arguments
{
using
Epilogue
=
typename
GEMM
::
EpilogueCombination
<
typename
LoraUp
::
EpilogueLoraUp
,
MidEpilogue
,
NextEpilogue
,
typename
GEMM
::
EpilogueNop
>
;
return
launch_bias
.
template
operator
()
<
Epilogue
>({
typename
LoraUp
::
EpilogueLoraUp
::
Arguments
{
.
lora_act
=
lora_act_in
.
data_ptr
<
float
>
(),
.
lora_wgt_up
=
lora_up
.
data_ptr
<
packed_fpsum_t
>
(),
.
rank
=
rank_up
,
...
...
@@ -229,8 +238,7 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
},
midArgs
,
nextArgs
,
{}
});
{}});
}
// assert(rank_down == rank_up);
...
...
@@ -246,9 +254,12 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
// dispatchVal(rank_down, std::integer_sequence<int, 16, 32, 48, 64, 80>(), [&]<int RANK_DOWN>() {
using
LoraDown
=
LoraUp
;
// GEMM::Lora<RANK_DOWN>;
using
Epilogue
=
typename
GEMM
::
EpilogueCombination
<
typename
LoraUp
::
EpilogueLoraUp
,
MidEpilogue
,
typename
LoraDown
::
EpilogueLoraDown
,
NextEpilogue
,
typename
GEMM
::
EpilogueNop
>
;
return
launch_bias
.
template
operator
()
<
Epilogue
>({
typename
LoraUp
::
EpilogueLoraUp
::
Arguments
{
using
Epilogue
=
typename
GEMM
::
EpilogueCombination
<
typename
LoraUp
::
EpilogueLoraUp
,
MidEpilogue
,
typename
LoraDown
::
EpilogueLoraDown
,
NextEpilogue
,
typename
GEMM
::
EpilogueNop
>
;
return
launch_bias
.
template
operator
()
<
Epilogue
>({
typename
LoraUp
::
EpilogueLoraUp
::
Arguments
{
.
lora_act
=
lora_act_in
.
data_ptr
<
float
>
(),
.
lora_wgt_up
=
lora_up
.
data_ptr
<
packed_fpsum_t
>
(),
.
rank
=
rank_up
,
...
...
@@ -263,8 +274,7 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
.
alwaysfalse
=
false
,
},
nextArgs
,
{}
});
{}});
// });
};
...
...
@@ -277,28 +287,27 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
constexpr
bool
USE_UNSIGNED
=
!
USE_FP4
;
using
EpilogueQuantize
=
typename
GEMM
::
EpilogueQuantize
<
false
,
USE_UNSIGNED
,
USE_FP4
>
;
auto
argsQuantize
=
typename
EpilogueQuantize
::
Arguments
{
.
qout
=
qout
.
data_ptr
<
packed_act_t
>
(),
auto
argsQuantize
=
typename
EpilogueQuantize
::
Arguments
{
.
qout
=
qout
.
data_ptr
<
packed_act_t
>
(),
.
oscales
=
oscales
.
data_ptr
<
typename
EpilogueQuantize
::
oscales_t
>
(),
.
shift_value
=
USE_FP4
?
0.0
f
:
SHIFT_GELU
,
.
smooth_factor
=
smooth_factor
.
data_ptr
<
packed_wscale_t
>
()
};
.
smooth_factor
=
smooth_factor
.
data_ptr
<
packed_wscale_t
>
()};
// TODO: check if gelu is needed
if
(
out
.
valid
())
{
launch_lora
.
template
operator
()
<
typename
GEMM
::
EpilogueCombination
<
typename
GEMM
::
EpilogueDefault
,
EpilogueQuantize
>,
typename
Epilogues
::
EpilogueGelu
>
({
typename
GEMM
::
EpilogueDefault
::
Arguments
{
launch_lora
.
template
operator
()
<
typename
GEMM
::
EpilogueCombination
<
typename
GEMM
::
EpilogueDefault
,
EpilogueQuantize
>,
typename
Epilogues
::
EpilogueGelu
>
({
typename
GEMM
::
EpilogueDefault
::
Arguments
{
.
out
=
out
.
data_ptr
<
half_t
>
(),
.
actualM
=
actualM
,
.
actualN
=
actualN
,
},
argsQuantize
},
{});
argsQuantize
},
{});
}
else
{
launch_lora
.
template
operator
()
<
EpilogueQuantize
,
typename
Epilogues
::
EpilogueGelu
>(
argsQuantize
,
{});
}
}
else
if
(
out_linearattn
.
valid
())
{
assert
(
out_vk
.
valid
());
...
...
@@ -326,12 +335,14 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
out_vk
.
zero_
();
launch_lora
.
template
operator
()
<
Epilogue
,
typename
GEMM
::
EpilogueNop
>(
typename
Epilogue
::
Arguments
{
launch_lora
.
template
operator
()
<
Epilogue
,
typename
GEMM
::
EpilogueNop
>(
typename
Epilogue
::
Arguments
{
.
out_q
=
out_linearattn
.
data_ptr
<
half_t
>
(),
.
out_vk
=
out_vk
.
data_ptr
<
float
>
(),
.
num_blocks_per_batch
=
num_blocks_per_batch
,
.
actualM
=
M
,
},
{});
},
{});
}
else
if
(
rotary_emb
.
valid
())
{
assert
(
norm_q
.
valid
());
...
...
@@ -342,8 +353,9 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
assert
(
rotary_emb
.
shape
[
0
]
*
rotary_emb
.
shape
[
1
]
==
M
);
assert
(
rotary_emb
.
shape
[
2
]
==
Epilogues
::
EpilogueRMSNormRope
::
HEAD_DIM
);
// assert(rotary_emb.numel() == M * GEMM::EpilogueQKVProj::HEAD_DIM / 2 * GEMM::EpilogueQKVProj::ROTARY_EMB_NUM_ELEMENTS);
// launch_lora.template operator()<typename GEMM::EpilogueQKVProj, typename GEMM::EpilogueNop>(typename GEMM::EpilogueQKVProj::Arguments{
// assert(rotary_emb.numel() == M * GEMM::EpilogueQKVProj::HEAD_DIM / 2 *
// GEMM::EpilogueQKVProj::ROTARY_EMB_NUM_ELEMENTS); launch_lora.template operator()<typename
// GEMM::EpilogueQKVProj, typename GEMM::EpilogueNop>(typename GEMM::EpilogueQKVProj::Arguments{
// .out = out.data_ptr<half_t>(),
// .actualM = actualM,
// .actualN = actualN,
...
...
@@ -363,27 +375,33 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
};
if
(
out_q
.
valid
())
{
launch_lora
.
template
operator
()
<
typename
GEMM
::
EpilogueCombination
<
EpilogueRope
,
typename
Epilogues
::
EpiloguePackQKV
>,
typename
GEMM
::
EpilogueNop
>
({
argsRope
,
launch_lora
.
template
operator
()
<
typename
GEMM
::
EpilogueCombination
<
EpilogueRope
,
typename
Epilogues
::
EpiloguePackQKV
>,
typename
GEMM
::
EpilogueNop
>
(
{
argsRope
,
typename
Epilogues
::
EpiloguePackQKV
::
Arguments
{
.
out_q
=
out_q
.
data_ptr
<
typename
Epilogues
::
EpiloguePackQKV
::
packed_qkv_t
>
(),
.
out_k
=
out_k
.
data_ptr
<
typename
Epilogues
::
EpiloguePackQKV
::
packed_qkv_t
>
(),
.
out_v
=
out_v
.
data_ptr
<
typename
Epilogues
::
EpiloguePackQKV
::
packed_qkv_t
>
(),
.
actualM
=
attn_tokens
,
.
strideHead_q
=
int
(
out_q
.
stride
(
1
)
*
out_q
.
scalar_size
()
/
sizeof
(
typename
Epilogues
::
EpiloguePackQKV
::
packed_qkv_t
)),
.
strideHead_k
=
int
(
out_k
.
stride
(
1
)
*
out_k
.
scalar_size
()
/
sizeof
(
typename
Epilogues
::
EpiloguePackQKV
::
packed_qkv_t
)),
.
strideHead_v
=
int
(
out_v
.
stride
(
1
)
*
out_v
.
scalar_size
()
/
sizeof
(
typename
Epilogues
::
EpiloguePackQKV
::
packed_qkv_t
)),
}
},
{});
.
strideHead_q
=
int
(
out_q
.
stride
(
1
)
*
out_q
.
scalar_size
()
/
sizeof
(
typename
Epilogues
::
EpiloguePackQKV
::
packed_qkv_t
)),
.
strideHead_k
=
int
(
out_k
.
stride
(
1
)
*
out_k
.
scalar_size
()
/
sizeof
(
typename
Epilogues
::
EpiloguePackQKV
::
packed_qkv_t
)),
.
strideHead_v
=
int
(
out_v
.
stride
(
1
)
*
out_v
.
scalar_size
()
/
sizeof
(
typename
Epilogues
::
EpiloguePackQKV
::
packed_qkv_t
)),
}},
{});
}
else
{
launch_lora
.
template
operator
()
<
typename
GEMM
::
EpilogueCombination
<
EpilogueRope
,
typename
GEMM
::
EpilogueDefault
>,
typename
GEMM
::
EpilogueNop
>
({
argsRope
,
launch_lora
.
template
operator
()
<
typename
GEMM
::
EpilogueCombination
<
EpilogueRope
,
typename
GEMM
::
EpilogueDefault
>,
typename
GEMM
::
EpilogueNop
>
({
argsRope
,
typename
GEMM
::
EpilogueDefault
::
Arguments
{
.
out
=
out
.
data_ptr
<
half_t
>
(),
.
actualM
=
actualM
,
.
actualN
=
actualN
,
}
},
{});
}
},
{});
}
}
else
if
(
out
.
valid
())
{
...
...
@@ -423,17 +441,21 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::linearattn_vk_mul_q(Tensor q, Tensor vk)
BLOCK_SIZE
=
128
;
}
invoke_kernel
<
typename
Epilogue
::
vk_mul_q_kernel
><<<
dim3
(
ceilDiv
(
num_tokens
,
BLOCK_SIZE
),
num_heads
,
batch_size
),
BLOCK_SIZE
,
0
,
getCurrentCUDAStream
()
>>>
(
q
.
data_ptr
<
half_t
>
(),
vk
.
data_ptr
<
float
>
(),
1e-6
f
,
num_tokens
);
invoke_kernel
<
typename
Epilogue
::
vk_mul_q_kernel
>
<<<
dim3
(
ceilDiv
(
num_tokens
,
BLOCK_SIZE
),
num_heads
,
batch_size
),
BLOCK_SIZE
,
0
,
getCurrentCUDAStream
()
>>>
(
q
.
data_ptr
<
half_t
>
(),
vk
.
data_ptr
<
float
>
(),
1e-6
f
,
num_tokens
);
checkCUDA
(
cudaGetLastError
());
}
template
<
typename
Config
,
bool
USE_FP4
>
void
GEMM_W4A4_Launch
<
Config
,
USE_FP4
>::
quantize_w4a4_act_fuse_lora
(
Tensor
input
,
Tensor
output
,
Tensor
oscales
,
Tensor
lora_down
,
Tensor
lora_act_out
,
Tensor
smooth
,
bool
fuse_glu
,
bool
fp4
)
{
void
GEMM_W4A4_Launch
<
Config
,
USE_FP4
>::
quantize_w4a4_act_fuse_lora
(
Tensor
input
,
Tensor
output
,
Tensor
oscales
,
Tensor
lora_down
,
Tensor
lora_act_out
,
Tensor
smooth
,
bool
fuse_glu
,
bool
fp4
)
{
const
int
actualM
=
input
.
numel
()
/
input
.
shape
[
-
1
];
const
int
actualN
=
input
.
shape
[
-
1
];
...
...
@@ -475,7 +497,8 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act_fuse_lora(Tensor input
checkCUDA
(
cudaFuncSetAttribute
(
func
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
kernel
::
SHMEM_SIZE
));
// log(std::format("quantize_w4a4_act_fuse_lora M={} N={} input={} output={} (size={} numel={})", M, N, input.data_ptr(), output.data_ptr(), output.buffer->getSize(), output.numel()));
// log(std::format("quantize_w4a4_act_fuse_lora M={} N={} input={} output={} (size={} numel={})", M, N,
// input.data_ptr(), output.data_ptr(), output.buffer->getSize(), output.numel()));
func
<<<
grid
,
GEMM
::
WARP_SIZE
*
GEMM
::
NUM_WARPS
,
kernel
::
SHMEM_SIZE
,
getCurrentCUDAStream
()
>>>
(
typename
kernel
::
Arguments
{
...
...
@@ -491,8 +514,7 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act_fuse_lora(Tensor input
.
actualM
=
actualM
,
.
actualN
=
actualN
,
.
alwaysfalse
=
false
,
}
);
});
checkCUDA
(
cudaGetLastError
());
});
// });
...
...
@@ -518,11 +540,7 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act(Tensor input, Tensor o
dim3
grid
(
M
/
GEMM
::
WARP_M
,
K
/
GEMM
::
WARP_K
);
invoke_kernel
<
typename
GEMM
::
quantize_w4a4_act_kernel
><<<
grid
,
GEMM
::
WARP_SIZE
,
0
,
getCurrentCUDAStream
()
>>>
(
input
.
data_ptr
<
half_t
>
(),
output
.
data_ptr
<
packed_act_t
>
(),
oscales
.
data_ptr
<
packed_ascale_t
>
(),
K
);
input
.
data_ptr
<
half_t
>
(),
output
.
data_ptr
<
packed_act_t
>
(),
oscales
.
data_ptr
<
packed_ascale_t
>
(),
K
);
checkCUDA
(
cudaGetLastError
());
}
...
...
@@ -547,11 +565,7 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_wgt(Tensor input, Tensor o
dim3
grid
(
N
/
GEMM
::
WARP_N
,
K
/
GEMM
::
WARP_K
);
invoke_kernel
<
typename
GEMM
::
quantize_w4a4_wgt_kernel
><<<
grid
,
GEMM
::
WARP_SIZE
,
0
,
getCurrentCUDAStream
()
>>>
(
input
.
data_ptr
<
half_t
>
(),
output
.
data_ptr
<
packed_wgt_t
>
(),
oscales
.
data_ptr
<
packed_wscale_t
>
(),
K
);
input
.
data_ptr
<
half_t
>
(),
output
.
data_ptr
<
packed_wgt_t
>
(),
oscales
.
data_ptr
<
packed_wscale_t
>
(),
K
);
checkCUDA
(
cudaGetLastError
());
}
...
...
src/kernels/zgemm/gemm_w4a4_test.cu
View file @
37a27712
...
...
@@ -26,8 +26,7 @@ void test_rmsnorm_rope(Tensor input, Tensor output, Tensor norm_q, Tensor norm_k
dim3
grid
(
M
/
GEMM
::
BLOCK_M
,
N
/
GEMM
::
BLOCK_N
);
func
<<<
grid
,
GEMM
::
WARP_SIZE
*
GEMM
::
NUM_WARPS
,
kernel
::
SHMEM_SIZE
,
getCurrentCUDAStream
()
>>>
(
typename
kernel
::
Arguments
{
.
input
=
input
.
data_ptr
<
GEMM
::
half_t
>
(),
typename
kernel
::
Arguments
{.
input
=
input
.
data_ptr
<
GEMM
::
half_t
>
(),
.
output
=
output
.
data_ptr
<
GEMM
::
half_t
>
(),
.
M
=
M
,
.
N
=
N
,
...
...
@@ -38,9 +37,7 @@ void test_rmsnorm_rope(Tensor input, Tensor output, Tensor norm_q, Tensor norm_k
.
rmsnorm_weight_q
=
norm_q
.
data_ptr
<
GEMM
::
half_t
>
(),
.
rmsnorm_weight_k
=
norm_k
.
data_ptr
<
GEMM
::
half_t
>
(),
.
epsilon
=
1e-6
,
}
}
);
}});
checkCUDA
(
cudaGetLastError
());
}
...
...
@@ -79,12 +76,13 @@ void test_pack_qkv(Tensor input, Tensor out_q, Tensor out_k, Tensor out_v, int n
.
out_k
=
out_k
.
data_ptr
<
typename
Epilogue
::
packed_qkv_t
>
(),
.
out_v
=
out_v
.
data_ptr
<
typename
Epilogue
::
packed_qkv_t
>
(),
.
actualM
=
numTokens
,
.
strideHead_q
=
int
(
out_q
.
stride
(
1
)
*
out_q
.
scalar_size
()
/
sizeof
(
GEMM
::
EpiloguePackQKV
::
packed_qkv_t
)),
.
strideHead_k
=
int
(
out_k
.
stride
(
1
)
*
out_k
.
scalar_size
()
/
sizeof
(
GEMM
::
EpiloguePackQKV
::
packed_qkv_t
)),
.
strideHead_v
=
int
(
out_v
.
stride
(
1
)
*
out_v
.
scalar_size
()
/
sizeof
(
GEMM
::
EpiloguePackQKV
::
packed_qkv_t
)),
}
}
);
.
strideHead_q
=
int
(
out_q
.
stride
(
1
)
*
out_q
.
scalar_size
()
/
sizeof
(
GEMM
::
EpiloguePackQKV
::
packed_qkv_t
)),
.
strideHead_k
=
int
(
out_k
.
stride
(
1
)
*
out_k
.
scalar_size
()
/
sizeof
(
GEMM
::
EpiloguePackQKV
::
packed_qkv_t
)),
.
strideHead_v
=
int
(
out_v
.
stride
(
1
)
*
out_v
.
scalar_size
()
/
sizeof
(
GEMM
::
EpiloguePackQKV
::
packed_qkv_t
)),
}});
checkCUDA
(
cudaGetLastError
());
}
...
...
src/kernels/zgemm/gemm_w8a8.cu
View file @
37a27712
...
...
@@ -17,24 +17,22 @@ void quantize_w8a8_act(Tensor input, Tensor output, Tensor oscales, bool fuse_gl
assert
(
oscales
.
numel
()
==
M
*
1
);
auto
launch
=
[
&
]
<
bool
FUSE_GLU
>
()
{
using
kernel
=
GEMM
::
quantize_w8a8_act_kernel
<
FUSE_GLU
>
;
assert
(
kernel
::
check
(
M
,
K
));
dim3
grid
=
kernel
::
gridSize
(
M
,
K
);
dim3
block
=
kernel
::
blockSize
(
M
,
K
);
auto
func
=
invoke_kernel
<
kernel
,
const
GEMM
::
half_t
*
,
GEMM
::
packed_act_t
*
,
GEMM
::
packed_ascale_t
*
,
int
,
bool
>
;
auto
func
=
invoke_kernel
<
kernel
,
const
GEMM
::
half_t
*
,
GEMM
::
packed_act_t
*
,
GEMM
::
packed_ascale_t
*
,
int
,
bool
>
;
checkCUDA
(
cudaFuncSetAttribute
(
func
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
92160
));
func
<<<
grid
,
block
,
kernel
::
smemSize
(
M
,
K
)
>>>
(
input
.
data_ptr
<
GEMM
::
half_t
>
(),
func
<<<
grid
,
block
,
kernel
::
smemSize
(
M
,
K
)
>>>
(
input
.
data_ptr
<
GEMM
::
half_t
>
(),
output
.
data_ptr
<
GEMM
::
packed_act_t
>
(),
oscales
.
data_ptr
<
GEMM
::
packed_ascale_t
>
(),
K
,
false
);
false
);
checkCUDA
(
cudaGetLastError
());
};
...
...
@@ -50,9 +48,7 @@ void gemm_w8a8(Tensor act, // [M, K]
Tensor
out
,
// [M, N]
Tensor
ascales
,
// [1, M]
Tensor
wscales
,
// [1, N]
Tensor
bias
)
{
Tensor
bias
)
{
using
GEMM
=
GEMM_W8A8
;
int
M
=
act
.
numel
()
/
act
.
shape
[
-
1
];
...
...
@@ -78,16 +74,18 @@ void gemm_w8a8(Tensor act, // [M, K]
std
::
swap
(
grid
.
x
,
grid
.
y
);
}
invoke_kernel
<
GEMM
::
gemm_w8a8_kernel
<
Epilogue
>>
<<<
grid
,
GEMM
::
WARP_SIZE
*
GEMM
::
NUM_WARPS
>>>
(
act
.
data_ptr
<
GEMM
::
packed_act_t
>
(),
invoke_kernel
<
GEMM
::
gemm_w8a8_kernel
<
Epilogue
>>
<<<
grid
,
GEMM
::
WARP_SIZE
*
GEMM
::
NUM_WARPS
>>>
(
act
.
data_ptr
<
GEMM
::
packed_act_t
>
(),
wgt
.
data_ptr
<
GEMM
::
packed_wgt_t
>
(),
ascales
.
data_ptr
<
GEMM
::
packed_ascale_t
>
(),
wscales
.
data_ptr
<
GEMM
::
packed_wscale_t
>
(),
// out.valid() ? out.data_ptr<GEMM::half_t>() : nullptr,
M
,
N
,
K
,
args
,
M
,
N
,
K
,
args
,
swapBlockMN
,
false
);
false
);
checkCUDA
(
cudaGetLastError
());
};
...
...
@@ -98,16 +96,15 @@ void gemm_w8a8(Tensor act, // [M, K]
assert
(
bias
.
numel
()
==
N
);
// append EpilgoueNop to workaround mismatched memory layout of std::tuple between device and host code on Windows
// append EpilgoueNop to workaround mismatched memory layout of std::tuple between device and host code on
// Windows
// ** sizeof(std::tuple<std::tuple<int>>) == 8 on device **
using
Epilogue
=
GEMM
::
EpilogueCombination
<
GEMM
::
EpilogueBias
<
true
,
false
>
,
NextEpilogue
,
GEMM
::
EpilogueNop
>
;
return
launch
.
template
operator
()
<
Epilogue
>({
GEMM
::
EpilogueBias
<
true
,
false
>::
Arguments
{
return
launch
.
template
operator
()
<
Epilogue
>({
GEMM
::
EpilogueBias
<
true
,
false
>::
Arguments
{
.
bias
=
bias
.
data_ptr
<
GEMM
::
packed_wscale_t
>
(),
},
nextArgs
,
{}
});
{}});
};
launch_bias
.
template
operator
()
<
GEMM
::
EpilogueDefault
>(
GEMM
::
EpilogueDefault
::
Arguments
{
...
...
src/kernels/zgemm/gemm_w8a8.cuh
View file @
37a27712
...
...
@@ -8,48 +8,52 @@ class GEMM_W8A8 : public GEMMBase<GEMMConfig_W8A8> {
public:
using
psum_warp
=
std
::
array
<
packed_psum_t
,
WARP_M_TILES
*
WARP_N_TILES
>
;
__device__
__forceinline__
static
packed_psum_t
mma
(
packed_act_t
act
,
packed_wgt_t
wgt
,
packed_psum_t
psum
)
{
__device__
__forceinline__
static
packed_psum_t
mma
(
packed_act_t
act
,
packed_wgt_t
wgt
,
packed_psum_t
psum
)
{
// packed_psum_t psum;
asm
volatile
(
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 "
asm
volatile
(
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};
\n
"
:
"=r"
(
psum
.
data
[
0
]),
"=r"
(
psum
.
data
[
1
]),
"=r"
(
psum
.
data
[
2
]),
"=r"
(
psum
.
data
[
3
])
:
"r"
(
act
.
x
),
"r"
(
act
.
y
),
"r"
(
act
.
z
),
"r"
(
act
.
w
),
"r"
(
wgt
.
x
),
"r"
(
wgt
.
y
),
:
"=r"
(
psum
.
data
[
0
]),
"=r"
(
psum
.
data
[
1
]),
"=r"
(
psum
.
data
[
2
]),
"=r"
(
psum
.
data
[
3
])
:
"r"
(
act
.
x
),
"r"
(
act
.
y
),
"r"
(
act
.
z
),
"r"
(
act
.
w
),
"r"
(
wgt
.
x
),
"r"
(
wgt
.
y
),
// "r"(0), "r"(0), "r"(0), "r"(0)
"r"
(
psum
.
data
[
0
]),
"r"
(
psum
.
data
[
1
]),
"r"
(
psum
.
data
[
2
]),
"r"
(
psum
.
data
[
3
])
);
asm
volatile
(
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 "
"r"
(
psum
.
data
[
0
]),
"r"
(
psum
.
data
[
1
]),
"r"
(
psum
.
data
[
2
]),
"r"
(
psum
.
data
[
3
]));
asm
volatile
(
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};
\n
"
:
"=r"
(
psum
.
data
[
4
]),
"=r"
(
psum
.
data
[
5
]),
"=r"
(
psum
.
data
[
6
]),
"=r"
(
psum
.
data
[
7
])
:
"r"
(
act
.
x
),
"r"
(
act
.
y
),
"r"
(
act
.
z
),
"r"
(
act
.
w
),
"r"
(
wgt
.
z
),
"r"
(
wgt
.
w
),
:
"=r"
(
psum
.
data
[
4
]),
"=r"
(
psum
.
data
[
5
]),
"=r"
(
psum
.
data
[
6
]),
"=r"
(
psum
.
data
[
7
])
:
"r"
(
act
.
x
),
"r"
(
act
.
y
),
"r"
(
act
.
z
),
"r"
(
act
.
w
),
"r"
(
wgt
.
z
),
"r"
(
wgt
.
w
),
// "r"(0), "r"(0), "r"(0), "r"(0)
"r"
(
psum
.
data
[
4
]),
"r"
(
psum
.
data
[
5
]),
"r"
(
psum
.
data
[
6
]),
"r"
(
psum
.
data
[
7
])
);
"r"
(
psum
.
data
[
4
]),
"r"
(
psum
.
data
[
5
]),
"r"
(
psum
.
data
[
6
]),
"r"
(
psum
.
data
[
7
]));
return
psum
;
}
__device__
__forceinline__
static
void
compute
(
act_warp
A
,
wgt_warp
W
,
psum_warp
&
psum
)
{
__device__
__forceinline__
static
void
compute
(
act_warp
A
,
wgt_warp
W
,
psum_warp
&
psum
)
{
const
int
laneId
=
threadIdx
.
x
%
WARP_SIZE
;
const
int
warpId
=
threadIdx
.
x
/
WARP_SIZE
;
#pragma unroll
#pragma unroll
for
(
int
j
=
0
;
j
<
WARP_N_TILES
;
j
++
)
{
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_M_TILES
;
i
++
)
{
psum
[
i
*
WARP_N_TILES
+
j
]
=
mma
(
A
[
i
],
W
[
j
],
psum
[
i
*
WARP_N_TILES
+
j
]);
}
...
...
@@ -62,11 +66,12 @@ public:
* oscales is per-warp (in shared memory)
* output is per-thread (in regs)
* shmem must be at least INSN_M * (INSN_K * sizeof(element) + 16) (16 * 32 = 512 Bytes)
* default to quantize activation, if quantize weight, input should be column-majored and output should be transposed ({x, y, z, w} = {x, z, y, w})
* default to quantize activation, if quantize weight, input should be column-majored and output should be
* transposed ({x, y, z, w} = {x, z, y, w})
*/
template
<
bool
input_shmem
=
false
>
__device__
__forceinline__
static
void
quantize_w8a8_warp
(
const
half_t
*
input
,
const
half_t
*
oscales
,
int
stride
,
packed_act_t
&
output
,
void
*
shmem
)
{
__device__
__forceinline__
static
void
quantize_w8a8_warp
(
const
half_t
*
input
,
const
half_t
*
oscales
,
int
stride
,
packed_act_t
&
output
,
void
*
shmem
)
{
const
int
laneId
=
threadIdx
.
x
%
WARP_SIZE
;
constexpr
int
QUANTIZE_BITWIDTH
=
8
;
...
...
@@ -75,8 +80,9 @@ public:
// 1 lane = 1 pack
// 1 warp = 32 lanes = 32 packs = 1 packwarp
// a pack is {a0, ..., a7} in figure https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=ex2#mma-16864-a
// PACK_SIZE * 4 = INSN_K / 2
// a pack is {a0, ..., a7} in figure
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=ex2#mma-16864-a PACK_SIZE * 4 =
// INSN_K / 2
constexpr
int
PACK_SIZE
=
INSN_K
/
8
;
// = 4 for 8bit
constexpr
int
NUM_PACKS_PER_ROW
=
INSN_K
/
PACK_SIZE
;
constexpr
int
NUM_ROWS_PER_PACKWARP
=
PACK_SIZE
*
WARP_SIZE
/
INSN_K
;
...
...
@@ -86,7 +92,7 @@ public:
packed_input
packs
[
NUM_PACKWARPS
];
// load
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_PACKWARPS
;
i
++
)
{
int
rowId
=
i
*
NUM_ROWS_PER_PACKWARP
+
laneId
/
NUM_PACKS_PER_ROW
;
int
colId
=
laneId
%
NUM_PACKS_PER_ROW
*
PACK_SIZE
;
...
...
@@ -96,7 +102,7 @@ public:
// quantize
using
matrix_t
=
uint32_t
[
INSN_M
][
NUM_PACKS_PER_ROW
];
matrix_t
&
mat
=
*
reinterpret_cast
<
matrix_t
*>
(
shmem
);
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_PACKWARPS
;
i
++
)
{
const
int
row
=
i
*
NUM_ROWS_PER_PACKWARP
+
laneId
/
NUM_PACKS_PER_ROW
;
const
int
col
=
laneId
%
NUM_PACKS_PER_ROW
;
...
...
@@ -104,7 +110,7 @@ public:
float
rscale
=
cuda_frcp
(
float
(
oscales
[
row
]));
uint32_t
qpack
=
0
;
#pragma unroll
#pragma unroll
for
(
int
j
=
0
;
j
<
PACK_SIZE
;
j
+=
2
)
{
// half2_t hval = __hmul2(half2_t(rscale, rscale), half2_t(packs[i][j], packs[i][j + 1]));
float2
fval
=
half22float2
(
half2_t
(
packs
[
i
][
j
],
packs
[
i
][
j
+
1
]))
*
float2
(
rscale
,
rscale
);
...
...
@@ -126,8 +132,8 @@ public:
* each warp finds absmax from a row
*/
template
<
bool
fuse_glu
=
false
>
__device__
__forceinline__
static
half_t
findmax_warp
(
const
half_t
*
input
,
half_t
*
output_shmem
,
int
K
,
bool
alwaysfalse
)
{
__device__
__forceinline__
static
half_t
findmax_warp
(
const
half_t
*
input
,
half_t
*
output_shmem
,
int
K
,
bool
alwaysfalse
)
{
const
int
laneId
=
threadIdx
.
x
%
WARP_SIZE
;
using
packed_input
=
std
::
array
<
half2_t
,
4
>
;
...
...
@@ -136,10 +142,10 @@ public:
constexpr
int
PACK_SIZE
=
sizeof
(
packed_input
)
/
sizeof
(
half_t
);
constexpr
int
NUM_STAGES
=
2
;
half2_t
maxvalue2
=
{
0
,
0
};
half2_t
maxvalue2
=
{
0
,
0
};
packed_input
pack
[
NUM_STAGES
];
#pragma unroll
#pragma unroll
for
(
int
k
=
0
;
k
<
NUM_STAGES
-
1
;
k
++
)
{
const
int
idx
=
k
*
PACK_SIZE
*
WARP_SIZE
+
laneId
*
PACK_SIZE
;
if
(
idx
<
K
)
{
...
...
@@ -155,7 +161,7 @@ public:
// TODO: store quantized data to shmem (instead of half)
for
(
int
k1
=
0
;
k1
<
ceilDiv
(
K
,
PACK_SIZE
*
WARP_SIZE
);
k1
+=
NUM_STAGES
)
{
#pragma unroll
#pragma unroll
for
(
int
k2
=
0
;
k2
<
NUM_STAGES
;
k2
++
)
{
const
int
nextidx
=
(
k1
+
k2
+
NUM_STAGES
-
1
)
*
PACK_SIZE
*
WARP_SIZE
+
laneId
*
PACK_SIZE
;
...
...
@@ -172,7 +178,7 @@ public:
if
constexpr
(
fuse_glu
)
{
packed_gated_input
gated
;
#pragma unroll
#pragma unroll
for
(
int
j
=
0
;
j
<
p
.
size
();
j
++
)
{
gated
[
j
]
=
p
[
j
].
x
*
gelu_half
(
p
[
j
].
y
);
p
[
j
].
x
=
gated
[
j
];
...
...
@@ -185,7 +191,7 @@ public:
}
}
#pragma unroll
#pragma unroll
for
(
int
j
=
0
;
j
<
p
.
size
();
j
++
)
{
maxvalue2
=
__hmax2
(
maxvalue2
,
__habs2
(
p
[
j
]));
}
...
...
@@ -194,7 +200,7 @@ public:
// unused_var(dummy, alwaysfalse);
#pragma unroll
#pragma unroll
for
(
int
mask
=
32
/
2
;
mask
>
0
;
mask
/=
2
)
{
maxvalue2
=
__hmax2
(
maxvalue2
,
__shfl_xor_sync
(
~
0
,
maxvalue2
,
mask
));
}
...
...
@@ -223,8 +229,8 @@ public:
return
INSN_M
*
K2
*
sizeof
(
half_t
);
}
__device__
void
operator
()(
const
half_t
*
input
,
packed_act_t
*
output
,
packed_ascale_t
*
oscales
,
int
K
,
bool
alwaysfalse
)
{
__device__
void
operator
()(
const
half_t
*
input
,
packed_act_t
*
output
,
packed_ascale_t
*
oscales
,
int
K
,
bool
alwaysfalse
)
{
// for quantize kernel
const
int
laneId
=
threadIdx
.
x
%
WARP_SIZE
;
const
int
warpId
=
threadIdx
.
x
/
WARP_SIZE
;
...
...
@@ -235,7 +241,6 @@ public:
const
int
bm
=
blockIdx
.
x
/
(
BLOCK_M
/
WARP_M
);
const
int
gemmWarpId
=
blockIdx
.
x
%
(
BLOCK_M
/
WARP_M
);
__shared__
alignas
(
128
)
half_t
oscale_shmem
[
WARP_M
];
// __shared__ alignas(128) half_t maxv_shmem[WARP_M];
__shared__
alignas
(
128
)
uint8_t
tmp_shmem
[
NUM_WARPS
][
512
];
...
...
@@ -267,36 +272,27 @@ public:
packed_act_t
tmpout
;
if
constexpr
(
fuse_glu
)
{
quantize_w8a8_warp
<
true
>
(
shmem
+
col
,
oscale_shmem
+
rowLocal
,
K2
,
tmpout
,
&
tmp_shmem
[
warpId
]
);
quantize_w8a8_warp
<
true
>
(
shmem
+
col
,
oscale_shmem
+
rowLocal
,
K2
,
tmpout
,
&
tmp_shmem
[
warpId
]);
}
else
{
quantize_w8a8_warp
<
false
>
(
input
+
rowGlobal
*
K
+
col
,
oscale_shmem
+
rowLocal
,
K
,
tmpout
,
&
tmp_shmem
[
warpId
]
);
input
+
rowGlobal
*
K
+
col
,
oscale_shmem
+
rowLocal
,
K
,
tmpout
,
&
tmp_shmem
[
warpId
]);
}
store
(
&
output
[(((
bm
*
K2
/
WARP_K
+
bk
)
*
NUM_WARPS
+
gemmWarpId
)
*
WARP_M_TILES
+
tileM
)
*
WARP_SIZE
+
laneId
],
tmpout
);
store
(
&
output
[(((
bm
*
K2
/
WARP_K
+
bk
)
*
NUM_WARPS
+
gemmWarpId
)
*
WARP_M_TILES
+
tileM
)
*
WARP_SIZE
+
laneId
],
tmpout
);
}
__syncthreads
();
}
// [M / BLOCK_M, 1, NUM_WARPS, ASCALES_NUM_PACKS, ASCALES_VALID_LANES] of packed_ascale_t
pack_ascales
(
oscale_shmem
,
&
oscales
[(
bm
*
NUM_WARPS
+
gemmWarpId
)
*
ASCALES_NUM_PACKS
*
ASCALES_VALID_LANES
]);
pack_ascales
(
oscale_shmem
,
&
oscales
[(
bm
*
NUM_WARPS
+
gemmWarpId
)
*
ASCALES_NUM_PACKS
*
ASCALES_VALID_LANES
]);
}
};
__device__
__forceinline__
static
gated_fpsum_warp
apply_glu
(
fpsum_warp
fpsum
)
{
__device__
__forceinline__
static
gated_fpsum_warp
apply_glu
(
fpsum_warp
fpsum
)
{
gated_fpsum_warp
result
;
for
(
int
i
=
0
;
i
<
WARP_M_TILES
;
i
++
)
{
for
(
int
j
=
0
;
j
<
WARP_N_TILES
;
j
++
)
{
...
...
@@ -310,10 +306,9 @@ public:
return
result
;
}
static
constexpr
int
unpack_gated_fpsum_shmem_size
=
INSN_M
*
(
WARP_N
/
2
+
8
)
*
sizeof
(
half_t
);
__device__
__forceinline__
static
void
unpack_gated_fpsum
(
gated_fpsum_warp
fpsum
,
half_t
*
output
,
int
stride
,
void
*
shmem
)
{
__device__
__forceinline__
static
void
unpack_gated_fpsum
(
gated_fpsum_warp
fpsum
,
half_t
*
output
,
int
stride
,
void
*
shmem
)
{
const
int
laneId
=
threadIdx
.
x
%
WARP_SIZE
;
constexpr
int
PACK_SIZE
=
WARP_N
/
2
/
WARP_SIZE
;
...
...
@@ -345,18 +340,17 @@ public:
// out: [M, N] <=> [..., NUM_WARPS, WARP_M, N] of half
template
<
typename
Epilogue
>
__device__
__forceinline__
static
void
gemm_w8a8_block
(
const
BlockInfo
binfo
,
__device__
__forceinline__
static
void
gemm_w8a8_block
(
const
BlockInfo
binfo
,
const
packed_act_t
*
act
,
const
packed_wgt_t
*
wgt
,
const
packed_ascale_t
*
ascales
,
const
packed_wscale_t
*
wscales
,
// half_t *out,
int
M
,
int
N
,
int
K
,
int
M
,
int
N
,
int
K
,
Epilogue
::
Arguments
epilogeParams
,
bool
alwaysfalse
)
{
bool
alwaysfalse
)
{
constexpr
int
NUM_STAGES
=
2
;
const
int
laneId
=
threadIdx
.
x
%
WARP_SIZE
;
...
...
@@ -389,7 +383,7 @@ public:
int
dummy
=
0
;
for
(
int
k1
=
0
;
k1
<
K
/
WARP_K
;
k1
+=
NUM_STAGES
)
{
#pragma unroll
#pragma unroll
for
(
int
k2
=
0
;
k2
<
NUM_STAGES
;
k2
++
)
{
int
nextk
=
k1
+
k2
+
NUM_STAGES
-
1
;
int
idx
=
(
k2
+
NUM_STAGES
-
1
)
%
NUM_STAGES
;
...
...
@@ -421,17 +415,15 @@ public:
f32psum_warp
f32psum
;
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
f32psum
.
size
();
i
++
)
{
#pragma unroll
#pragma unroll
for
(
int
j
=
0
;
j
<
8
;
j
++
)
{
f32psum
[
i
].
data
[
j
]
=
0
;
}
}
apply_scales
([
&
](
int
i
,
int
j
)
{
return
psum
[
i
*
WARP_N_TILES
+
j
];
},
ascale
,
wscale
,
f32psum
);
apply_scales
([
&
](
int
i
,
int
j
)
{
return
psum
[
i
*
WARP_N_TILES
+
j
];
},
ascale
,
wscale
,
f32psum
);
fpsum_warp
fpsum
=
packed_fp32_to_fp16
(
f32psum
);
...
...
@@ -443,24 +435,21 @@ public:
Epilogue
()(
binfo
,
fpsum
,
M
,
N
,
K
,
epilogeParams
);
}
// out : [M / BLOCK_M, BLOCK_M, N / BLOCK_N, BLOCK_N]
template
<
typename
Epilogue
>
struct
gemm_w8a8_kernel
{
static
constexpr
int
MIN_ARCH
=
std
::
is_same_v
<
half_t
,
__nv_bfloat16
>
?
800
:
750
;
__device__
void
operator
()(
const
packed_act_t
*
act
,
__device__
void
operator
()(
const
packed_act_t
*
act
,
const
packed_wgt_t
*
wgt
,
const
packed_ascale_t
*
ascales
,
const
packed_wscale_t
*
wscales
,
// half_t *out,
int
M
,
int
N
,
int
K
,
int
M
,
int
N
,
int
K
,
Epilogue
::
Arguments
epilogueArgs
,
bool
swapBlockXY
,
bool
alwaysfalse
)
{
bool
alwaysfalse
)
{
BlockInfo
binfo
=
{
.
bm
=
(
int
)
blockIdx
.
x
,
.
bn
=
(
int
)
blockIdx
.
y
,
...
...
@@ -476,25 +465,25 @@ public:
const
int
bm
=
binfo
.
bm
;
const
int
bn
=
binfo
.
bn
;
gemm_w8a8_block
<
Epilogue
>
(
binfo
,
gemm_w8a8_block
<
Epilogue
>
(
binfo
,
act
+
bm
*
(
K
/
WARP_K
)
*
NUM_WARPS
*
WARP_M_TILES
*
WARP_SIZE
,
wgt
+
bn
*
(
K
/
WARP_K
)
*
WARP_N_TILES
*
WARP_SIZE
,
ascales
+
bm
*
(
1
)
*
NUM_WARPS
*
ASCALES_NUM_PACKS
*
ASCALES_VALID_LANES
,
// only 1 group in W8A8
ascales
+
bm
*
(
1
)
*
NUM_WARPS
*
ASCALES_NUM_PACKS
*
ASCALES_VALID_LANES
,
// only 1 group in W8A8
wscales
+
bn
*
(
1
)
*
WSCALES_NUM_PACKS
*
WSCALES_VALID_LANES
,
// #if 1
// out + (bm * BLOCK_M * N) + bn * BLOCK_N,
// #else
// out + (bm * BLOCK_M * N / 2) + bn * BLOCK_N / 2,
// #endif
M
,
N
,
K
,
// #if 1
// out + (bm * BLOCK_M * N) + bn * BLOCK_N,
// #else
// out + (bm * BLOCK_M * N / 2) + bn * BLOCK_N / 2,
// #endif
M
,
N
,
K
,
epilogueArgs
,
alwaysfalse
);
alwaysfalse
);
}
};
#if 0
struct EpilogueGLU {
struct Arguments { size_t unused; };
...
...
@@ -510,9 +499,6 @@ public:
}
};
#endif
};
};
// namespace nunchaku::kernels
src/kernels/zgemm/lora.cuh
View file @
37a27712
...
...
@@ -2,7 +2,6 @@
#include "gemm_base.cuh"
namespace
nunchaku
::
kernels
{
template
<
typename
Config
>
...
...
@@ -43,8 +42,8 @@ public:
// lora_wgt: [N / 16, rank / WARP_R, LORA_R_TILES, WARP_SIZE] of packed_fpsum_t
// [N / 16, rank / 16, WARP_SIZE]
__device__
__forceinline__
static
void
load_lora_wgt
(
const
packed_fpsum_t
*
ptr
,
int
rtile
,
int
rank
,
lora_wgt_warp
&
result
,
bool
pred
)
{
__device__
__forceinline__
static
void
load_lora_wgt
(
const
packed_fpsum_t
*
ptr
,
int
rtile
,
int
rank
,
lora_wgt_warp
&
result
,
bool
pred
)
{
const
int
laneId
=
threadIdx
.
x
%
WARP_SIZE
;
const
packed_fpsum_t
*
ptr_lane
=
&
ptr
[
rtile
*
LORA_R_TILES
*
WARP_SIZE
+
laneId
];
...
...
@@ -60,15 +59,16 @@ public:
}
// lora_act: [M / BLOCK_M, rank / WARP_R, NUM_WARPS, LORA_M_TILES, LORA_R_TILES, 8, WARP_SIZE] of float
__device__
__forceinline__
static
void
load_lora_act
(
const
float
*
ptr
,
int
rtile
,
lora_act_warp
&
result
,
bool
pred
)
{
__device__
__forceinline__
static
void
load_lora_act
(
const
float
*
ptr
,
int
rtile
,
lora_act_warp
&
result
,
bool
pred
)
{
const
int
laneId
=
threadIdx
.
x
%
WARP_SIZE
;
const
int
warpId
=
threadIdx
.
x
/
WARP_SIZE
;
const
float
*
ptrlane
=
&
ptr
[(
rtile
*
NUM_WARPS
+
warpId
)
*
LORA_M_TILES
*
LORA_R_TILES
*
8
*
WARP_SIZE
+
laneId
];
const
float
*
ptrlane
=
&
ptr
[(
rtile
*
NUM_WARPS
+
warpId
)
*
LORA_M_TILES
*
LORA_R_TILES
*
8
*
WARP_SIZE
+
laneId
];
unrolled_loop
<
LORA_M_TILES
>
([
&
]
<
int
m
>
()
{
unrolled_loop
<
LORA_R_TILES
>
([
&
]
<
int
r
>
{
unrolled_loop
<
LORA_R_TILES
>
([
&
]
<
int
r
>
{
constexpr
int
i
=
m
*
LORA_R_TILES
+
r
;
unrolled_loop
<
8
>
([
&
]
<
int
j
>
()
{
constexpr
int
offset
=
i
*
8
*
WARP_SIZE
+
j
*
WARP_SIZE
;
...
...
@@ -79,8 +79,7 @@ public:
});
}
// no vector reduction in sm_89 :(
__device__
__forceinline__
static
void
reduce_lora_act
(
float
*
ptr
,
int
rtile
,
lora_act_warp
val
,
bool
pred
)
{
__device__
__forceinline__
static
void
reduce_lora_act
(
float
*
ptr
,
int
rtile
,
lora_act_warp
val
,
bool
pred
)
{
const
int
laneId
=
threadIdx
.
x
%
WARP_SIZE
;
const
int
warpId
=
threadIdx
.
x
/
WARP_SIZE
;
...
...
@@ -108,7 +107,6 @@ public:
// });
// }
struct
EpilogueLoraUp
{
struct
Arguments
{
const
float
*
lora_act
;
...
...
@@ -120,8 +118,12 @@ public:
bool
alwaysfalse
;
};
__device__
__forceinline__
static
void
apply_lora_up
(
fpsum_warp
&
fpsum
,
const
float
*
act
,
const
packed_fpsum_t
*
wgt
,
const
scale_t
&
scales
,
int
rank
,
bool
alwaysfalse
)
{
__device__
__forceinline__
static
void
apply_lora_up
(
fpsum_warp
&
fpsum
,
const
float
*
act
,
const
packed_fpsum_t
*
wgt
,
const
scale_t
&
scales
,
int
rank
,
bool
alwaysfalse
)
{
constexpr
int
NUM_STAGES
=
2
;
const
int
laneId
=
threadIdx
.
x
%
WARP_SIZE
;
...
...
@@ -132,7 +134,7 @@ public:
int
dummy
=
0
;
#pragma unroll
#pragma unroll
for
(
int
k
=
0
;
k
<
NUM_STAGES
-
1
;
k
++
)
{
// we have rank > 0
const
bool
pred
=
k
==
0
?
true
:
k
<
rank
/
WARP_R
;
...
...
@@ -147,7 +149,7 @@ public:
for
(
int
m
=
0
;
m
<
LORA_M_TILES
;
m
++
)
{
for
(
int
r
=
0
;
r
<
LORA_R_TILES
;
r
++
)
{
packed_f32psum_t
pack
=
A
[
m
*
LORA_R_TILES
+
r
];
#pragma unroll
#pragma unroll
for
(
int
j
=
0
;
j
<
8
;
j
++
)
{
pack
.
data
[
j
]
*=
scales
[
rtile
*
LORA_R_TILES
+
r
];
}
...
...
@@ -159,15 +161,15 @@ public:
for
(
int
r
=
0
;
r
<
LORA_R_TILES
;
r
++
)
{
CHECK_NAN
(
lora_act
[
m
*
LORA_R_TILES
+
r
],
"lora_act"
);
CHECK_NAN
(
lora_wgt
[
n
*
LORA_R_TILES
+
r
],
"lora_wgt"
);
f32psum
[
m
*
WARP_N_TILES
+
n
]
=
mma_f16xf16_f32
(
A_fp16
[
m
*
LORA_R_TILES
+
r
],
W
[
n
*
LORA_R_TILES
+
r
],
f32psum
[
m
*
WARP_N_TILES
+
n
]);
f32psum
[
m
*
WARP_N_TILES
+
n
]
=
mma_f16xf16_f32
(
A_fp16
[
m
*
LORA_R_TILES
+
r
],
W
[
n
*
LORA_R_TILES
+
r
],
f32psum
[
m
*
WARP_N_TILES
+
n
]);
}
}
}
};
for
(
int
k1
=
0
;
k1
<
rank
/
WARP_R
;
k1
+=
NUM_STAGES
)
{
#pragma unroll
#pragma unroll
for
(
int
k2
=
0
;
k2
<
NUM_STAGES
;
k2
++
)
{
if
(
k1
+
k2
>=
rank
/
WARP_R
)
{
break
;
...
...
@@ -194,25 +196,24 @@ public:
// NVCC does not know rank > 0 :(
// it will generate a branch instruction to skip the initial load
// the branch splits the basic blocks and prevents the overlap of memory access and computing
(packed_fp16_to_fp32)
// add fake dependency of loaded data so NVCC will not skip the load
#pragma unroll
// the branch splits the basic blocks and prevents the overlap of memory access and computing
//
(packed_fp16_to_fp32)
add fake dependency of loaded data so NVCC will not skip the load
#pragma unroll
for
(
int
k
=
0
;
k
<
NUM_STAGES
-
1
;
k
++
)
{
#pragma unroll
#pragma unroll
for
(
auto
&&
data
:
lora_act
[
k
])
{
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
dummy
^=
kernels
::
bit_cast
<
int
>
(
data
.
data
[
i
]);
}
}
#pragma unroll
#pragma unroll
for
(
auto
&&
data
:
lora_wgt
[
k
])
{
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
dummy
^=
kernels
::
bit_cast
<
int
>
(
data
.
data
[
i
]);
}
}
}
unused_var
(
dummy
,
alwaysfalse
);
...
...
@@ -220,21 +221,20 @@ public:
fpsum
=
packed_fp32_to_fp16
(
f32psum
);
}
__device__
__forceinline__
void
operator
()(
const
BlockInfo
binfo
,
fpsum_warp
&
fpsum
,
int
M
,
int
N
,
int
K
,
const
Arguments
&
args
)
{
__device__
__forceinline__
void
operator
()(
const
BlockInfo
binfo
,
fpsum_warp
&
fpsum
,
int
M
,
int
N
,
int
K
,
const
Arguments
&
args
)
{
const
int
bm
=
binfo
.
bm
;
const
int
bn
=
binfo
.
bn
;
CHECK_NAN
(
fpsum
,
"fpsum"
);
apply_lora_up
(
fpsum
,
args
.
lora_act
+
bm
*
(
args
.
rank
/
WARP_R
)
*
(
NUM_WARPS
*
LORA_M_TILES
*
LORA_R_TILES
*
8
*
WARP_SIZE
),
apply_lora_up
(
fpsum
,
args
.
lora_act
+
bm
*
(
args
.
rank
/
WARP_R
)
*
(
NUM_WARPS
*
LORA_M_TILES
*
LORA_R_TILES
*
8
*
WARP_SIZE
),
args
.
lora_wgt_up
+
bn
*
(
BLOCK_N
/
16
)
*
(
args
.
rank
/
16
)
*
WARP_SIZE
,
args
.
scales
,
args
.
rank
,
args
.
alwaysfalse
);
args
.
alwaysfalse
);
CHECK_NAN
(
fpsum
,
"fpsum"
);
}
...
...
@@ -250,8 +250,8 @@ public:
bool
alwaysfalse
;
};
__device__
__forceinline__
static
void
apply_lora_down
(
fpsum_warp
&
fpsum
,
float
*
act
,
const
packed_fpsum_t
*
wgt
,
int
rank
,
bool
alwaysfalse
)
{
__device__
__forceinline__
static
void
apply_lora_down
(
fpsum_warp
&
fpsum
,
float
*
act
,
const
packed_fpsum_t
*
wgt
,
int
rank
,
bool
alwaysfalse
)
{
constexpr
int
NUM_STAGES
=
2
;
const
int
laneId
=
threadIdx
.
x
%
WARP_SIZE
;
...
...
@@ -259,7 +259,7 @@ public:
lora_wgt_warp
lora_wgt
[
NUM_STAGES
];
// 64
#pragma unroll
#pragma unroll
for
(
int
k
=
0
;
k
<
NUM_STAGES
-
1
;
k
++
)
{
// we have rank > 0
bool
pred
=
k
==
0
?
true
:
k
<
rank
/
WARP_R
;
...
...
@@ -270,11 +270,11 @@ public:
lora_act_warp
lora_act
;
lora_act
.
fill
(
packed_f32psum_t
::
zeros
());
#pragma unroll
#pragma unroll
for
(
int
m
=
0
;
m
<
LORA_M_TILES
;
m
++
)
{
#pragma unroll
#pragma unroll
for
(
int
n
=
0
;
n
<
LORA_N_TILES
;
n
++
)
{
#pragma unroll
#pragma unroll
for
(
int
r
=
0
;
r
<
LORA_R_TILES
;
r
++
)
{
auto
&
psum
=
lora_act
[
m
*
LORA_R_TILES
+
r
];
...
...
@@ -294,7 +294,7 @@ public:
int
dummy
=
0
;
for
(
int
k1
=
0
;
k1
<
rank
/
WARP_R
;
k1
+=
NUM_STAGES
)
{
#pragma unroll
#pragma unroll
for
(
int
k2
=
0
;
k2
<
NUM_STAGES
;
k2
++
)
{
if
(
k1
+
k2
>=
rank
/
WARP_R
)
{
break
;
...
...
@@ -324,38 +324,33 @@ public:
}
}
#pragma unroll
#pragma unroll
for
(
int
k
=
0
;
k
<
NUM_STAGES
-
1
;
k
++
)
{
#pragma unroll
#pragma unroll
for
(
auto
&&
data
:
lora_wgt
[
k
])
{
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
dummy
^=
kernels
::
bit_cast
<
int
>
(
data
.
data
[
i
]);
}
}
}
unused_var
(
dummy
,
alwaysfalse
);
}
__device__
__forceinline__
void
operator
()(
const
BlockInfo
binfo
,
fpsum_warp
&
fpsum
,
int
M
,
int
N
,
int
K
,
const
Arguments
&
args
)
{
__device__
__forceinline__
void
operator
()(
const
BlockInfo
binfo
,
fpsum_warp
&
fpsum
,
int
M
,
int
N
,
int
K
,
const
Arguments
&
args
)
{
const
int
bm
=
binfo
.
bm
;
const
int
bn
=
binfo
.
bn
;
apply_lora_down
(
fpsum
,
args
.
lora_act
+
bm
*
(
args
.
rank
/
WARP_R
)
*
(
NUM_WARPS
*
LORA_M_TILES
*
LORA_R_TILES
*
8
*
WARP_SIZE
),
apply_lora_down
(
fpsum
,
args
.
lora_act
+
bm
*
(
args
.
rank
/
WARP_R
)
*
(
NUM_WARPS
*
LORA_M_TILES
*
LORA_R_TILES
*
8
*
WARP_SIZE
),
args
.
lora_wgt_down
+
bn
*
(
BLOCK_N
/
16
)
*
(
args
.
rank
/
16
)
*
WARP_SIZE
,
args
.
rank
,
args
.
alwaysfalse
);
args
.
alwaysfalse
);
}
};
};
};
// namespace nunchaku::kernels
src/kernels/zgemm/mma.cuh
View file @
37a27712
...
...
@@ -7,53 +7,44 @@
namespace
nunchaku
::
kernels
{
namespace
mma_helper
{
struct
f32
{
struct
f32
{
static
constexpr
const
char
value
[]
=
"f32"
;
};
struct
f16
{
};
struct
f16
{
static
constexpr
const
char
value
[]
=
"f16"
;
};
struct
bf16
{
};
struct
bf16
{
static
constexpr
const
char
value
[]
=
"bf16"
;
};
struct
s32
{
};
struct
s32
{
static
constexpr
const
char
value
[]
=
"s32"
;
};
struct
s4
{
};
struct
s4
{
static
constexpr
const
char
value
[]
=
"s4"
;
};
struct
u4
{
};
struct
u4
{
static
constexpr
const
char
value
[]
=
"u4"
;
};
template
<
bool
is_bf16
>
using
f16bf16
=
std
::
conditional_t
<
is_bf16
,
bf16
,
f16
>
;
template
<
bool
is_unsigned
>
using
s4u4
=
std
::
conditional_t
<
is_unsigned
,
u4
,
s4
>
;
};
__device__
__forceinline__
static
uint2
mma_m16n8k16_f16f16f16f16
(
uint4
a
,
uint2
b
,
uint2
c
)
{
template
<
bool
is_bf16
>
using
f16bf16
=
std
::
conditional_t
<
is_bf16
,
bf16
,
f16
>
;
template
<
bool
is_unsigned
>
using
s4u4
=
std
::
conditional_t
<
is_unsigned
,
u4
,
s4
>
;
};
// namespace mma_helper
__device__
__forceinline__
static
uint2
mma_m16n8k16_f16f16f16f16
(
uint4
a
,
uint2
b
,
uint2
c
)
{
uint2
d
;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm
volatile
(
"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
asm
volatile
(
"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
"{%0, %1},"
"{%2, %3, %4, %5},"
"{%6, %7},"
"{%8, %9};
\n
"
:
"=r"
(
d
.
x
),
"=r"
(
d
.
y
)
:
"r"
(
a
.
x
),
"r"
(
a
.
y
),
"r"
(
a
.
z
),
"r"
(
a
.
w
),
"r"
(
b
.
x
),
"r"
(
b
.
y
),
"r"
(
c
.
x
),
"r"
(
c
.
y
)
);
:
"=r"
(
d
.
x
),
"=r"
(
d
.
y
)
:
"r"
(
a
.
x
),
"r"
(
a
.
y
),
"r"
(
a
.
z
),
"r"
(
a
.
w
),
"r"
(
b
.
x
),
"r"
(
b
.
y
),
"r"
(
c
.
x
),
"r"
(
c
.
y
));
#else
asm
volatile
(
"{"
asm
volatile
(
"{"
".reg .b32 tmp0, tmp1;"
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
"{tmp0, tmp1},"
...
...
@@ -66,40 +57,36 @@ static uint2 mma_m16n8k16_f16f16f16f16(uint4 a, uint2 b, uint2 c) {
"{%7},"
"{tmp0, tmp1};"
"}
\n
"
:
"=r"
(
d
.
x
),
"=r"
(
d
.
y
)
:
"r"
(
a
.
x
),
"r"
(
a
.
y
),
"r"
(
a
.
z
),
"r"
(
a
.
w
),
"r"
(
b
.
x
),
"r"
(
b
.
y
),
"r"
(
c
.
x
),
"r"
(
c
.
y
)
);
:
"=r"
(
d
.
x
),
"=r"
(
d
.
y
)
:
"r"
(
a
.
x
),
"r"
(
a
.
y
),
"r"
(
a
.
z
),
"r"
(
a
.
w
),
"r"
(
b
.
x
),
"r"
(
b
.
y
),
"r"
(
c
.
x
),
"r"
(
c
.
y
));
#endif
return
d
;
}
template
<
bool
is_bf16
>
__device__
__forceinline__
static
uint4
mma_m16n8k16_f32f16f16f32
(
uint4
a
,
uint2
b
,
uint4
c
)
{
__device__
__forceinline__
static
uint4
mma_m16n8k16_f32f16f16f32
(
uint4
a
,
uint2
b
,
uint4
c
)
{
uint4
d
;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm
volatile
(
"mma.sync.aligned.m16n8k16.row.col.f32.%14.%14.f32 "
asm
volatile
(
"mma.sync.aligned.m16n8k16.row.col.f32.%14.%14.f32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};
\n
"
:
"=r"
(
d
.
x
),
"=r"
(
d
.
y
),
"=r"
(
d
.
z
),
"=r"
(
d
.
w
)
:
"r"
(
a
.
x
),
"r"
(
a
.
y
),
"r"
(
a
.
z
),
"r"
(
a
.
w
),
"r"
(
b
.
x
),
"r"
(
b
.
y
),
"r"
(
c
.
x
),
"r"
(
c
.
y
),
"r"
(
c
.
z
),
"r"
(
c
.
w
),
"C"
(
mma_helper
::
f16bf16
<
is_bf16
>::
value
)
);
:
"=r"
(
d
.
x
),
"=r"
(
d
.
y
),
"=r"
(
d
.
z
),
"=r"
(
d
.
w
)
:
"r"
(
a
.
x
),
"r"
(
a
.
y
),
"r"
(
a
.
z
),
"r"
(
a
.
w
),
"r"
(
b
.
x
),
"r"
(
b
.
y
),
"r"
(
c
.
x
),
"r"
(
c
.
y
),
"r"
(
c
.
z
),
"r"
(
c
.
w
),
"C"
(
mma_helper
::
f16bf16
<
is_bf16
>::
value
));
#else
static_assert
(
!
is_bf16
);
asm
volatile
(
"{"
asm
volatile
(
"{"
".reg .b32 tmp0, tmp1, tmp2, tmp3;"
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
"{tmp0, tmp1, tmp2, tmp3},"
...
...
@@ -112,43 +99,39 @@ static uint4 mma_m16n8k16_f32f16f16f32(uint4 a, uint2 b, uint4 c) {
"{%9},"
"{tmp0, tmp1, tmp2, tmp3};"
"}
\n
"
:
"=r"
(
d
.
x
),
"=r"
(
d
.
y
),
"=r"
(
d
.
z
),
"=r"
(
d
.
w
)
:
"r"
(
a
.
x
),
"r"
(
a
.
y
),
"r"
(
a
.
z
),
"r"
(
a
.
w
),
"r"
(
b
.
x
),
"r"
(
b
.
y
),
"r"
(
c
.
x
),
"r"
(
c
.
y
),
"r"
(
c
.
z
),
"r"
(
c
.
w
)
);
:
"=r"
(
d
.
x
),
"=r"
(
d
.
y
),
"=r"
(
d
.
z
),
"=r"
(
d
.
w
)
:
"r"
(
a
.
x
),
"r"
(
a
.
y
),
"r"
(
a
.
z
),
"r"
(
a
.
w
),
"r"
(
b
.
x
),
"r"
(
b
.
y
),
"r"
(
c
.
x
),
"r"
(
c
.
y
),
"r"
(
c
.
z
),
"r"
(
c
.
w
));
#endif
return
d
;
}
template
<
typename
AType
,
typename
BType
>
__device__
__forceinline__
static
uint4
mma_m16n8kx_s32common
(
uint4
a
,
uint2
b
,
uint4
c
)
{
__device__
__forceinline__
static
uint4
mma_m16n8kx_s32common
(
uint4
a
,
uint2
b
,
uint4
c
)
{
uint4
d
;
static
constexpr
int
K
=
(
std
::
is_same_v
<
AType
,
mma_helper
::
s4
>
||
std
::
is_same_v
<
AType
,
mma_helper
::
u4
>
)
?
64
:
32
;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm
volatile
(
"mma.sync.aligned.m16n8k%14.row.col.s32.%15.%16.s32 "
asm
volatile
(
"mma.sync.aligned.m16n8k%14.row.col.s32.%15.%16.s32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};
\n
"
:
"=r"
(
d
.
x
),
"=r"
(
d
.
y
),
"=r"
(
d
.
z
),
"=r"
(
d
.
w
)
:
"r"
(
a
.
x
),
"r"
(
a
.
y
),
"r"
(
a
.
z
),
"r"
(
a
.
w
),
"r"
(
b
.
x
),
"r"
(
b
.
y
),
"r"
(
c
.
x
),
"r"
(
c
.
y
),
"r"
(
c
.
z
),
"r"
(
c
.
w
),
:
"=r"
(
d
.
x
),
"=r"
(
d
.
y
),
"=r"
(
d
.
z
),
"=r"
(
d
.
w
)
:
"r"
(
a
.
x
),
"r"
(
a
.
y
),
"r"
(
a
.
z
),
"r"
(
a
.
w
),
"r"
(
b
.
x
),
"r"
(
b
.
y
),
"r"
(
c
.
x
),
"r"
(
c
.
y
),
"r"
(
c
.
z
),
"r"
(
c
.
w
),
"n"
(
K
),
"C"
(
AType
::
value
),
"C"
(
BType
::
value
)
);
"C"
(
BType
::
value
));
#else
asm
volatile
(
"{"
asm
volatile
(
"{"
".reg .b32 tmp0, tmp1, tmp2, tmp3;"
"mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
"{tmp0, tmp1},"
...
...
@@ -171,19 +154,22 @@ static uint4 mma_m16n8kx_s32common(uint4 a, uint2 b, uint4 c) {
"{%9},"
"{tmp2, tmp3};
\n
"
"}
\n
"
:
"=r"
(
d
.
x
),
"=r"
(
d
.
y
),
"=r"
(
d
.
z
),
"=r"
(
d
.
w
)
:
"r"
(
a
.
x
),
"r"
(
a
.
y
),
"r"
(
a
.
z
),
"r"
(
a
.
w
),
"r"
(
b
.
x
),
"r"
(
b
.
y
),
"r"
(
c
.
x
),
"r"
(
c
.
y
),
"r"
(
c
.
z
),
"r"
(
c
.
w
),
:
"=r"
(
d
.
x
),
"=r"
(
d
.
y
),
"=r"
(
d
.
z
),
"=r"
(
d
.
w
)
:
"r"
(
a
.
x
),
"r"
(
a
.
y
),
"r"
(
a
.
z
),
"r"
(
a
.
w
),
"r"
(
b
.
x
),
"r"
(
b
.
y
),
"r"
(
c
.
x
),
"r"
(
c
.
y
),
"r"
(
c
.
z
),
"r"
(
c
.
w
),
"n"
(
K
/
2
),
"C"
(
AType
::
value
),
"C"
(
BType
::
value
)
);
"C"
(
BType
::
value
));
#endif
return
d
;
}
};
// namespace nunchaku::kernels
src/kernels/zgemm/mma_earlycuda.cuh
View file @
37a27712
...
...
@@ -6,56 +6,46 @@
// cuda 12.4- does not support "C" constraint in inline assembly :(
// use explicit specialization for now
namespace
nunchaku
::
kernels
{
namespace
mma_helper
{
struct
f32
{
struct
f32
{
static
constexpr
const
char
value
[]
=
"f32"
;
};
struct
f16
{
};
struct
f16
{
static
constexpr
const
char
value
[]
=
"f16"
;
};
struct
bf16
{
};
struct
bf16
{
static
constexpr
const
char
value
[]
=
"bf16"
;
};
struct
s32
{
};
struct
s32
{
static
constexpr
const
char
value
[]
=
"s32"
;
};
struct
s4
{
};
struct
s4
{
static
constexpr
const
char
value
[]
=
"s4"
;
};
struct
u4
{
};
struct
u4
{
static
constexpr
const
char
value
[]
=
"u4"
;
};
template
<
bool
is_bf16
>
using
f16bf16
=
std
::
conditional_t
<
is_bf16
,
bf16
,
f16
>
;
template
<
bool
is_unsigned
>
using
s4u4
=
std
::
conditional_t
<
is_unsigned
,
u4
,
s4
>
;
};
__device__
__forceinline__
static
uint2
mma_m16n8k16_f16f16f16f16
(
uint4
a
,
uint2
b
,
uint2
c
)
{
template
<
bool
is_bf16
>
using
f16bf16
=
std
::
conditional_t
<
is_bf16
,
bf16
,
f16
>
;
template
<
bool
is_unsigned
>
using
s4u4
=
std
::
conditional_t
<
is_unsigned
,
u4
,
s4
>
;
};
// namespace mma_helper
__device__
__forceinline__
static
uint2
mma_m16n8k16_f16f16f16f16
(
uint4
a
,
uint2
b
,
uint2
c
)
{
uint2
d
;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm
volatile
(
"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
asm
volatile
(
"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
"{%0, %1},"
"{%2, %3, %4, %5},"
"{%6, %7},"
"{%8, %9};
\n
"
:
"=r"
(
d
.
x
),
"=r"
(
d
.
y
)
:
"r"
(
a
.
x
),
"r"
(
a
.
y
),
"r"
(
a
.
z
),
"r"
(
a
.
w
),
"r"
(
b
.
x
),
"r"
(
b
.
y
),
"r"
(
c
.
x
),
"r"
(
c
.
y
)
);
:
"=r"
(
d
.
x
),
"=r"
(
d
.
y
)
:
"r"
(
a
.
x
),
"r"
(
a
.
y
),
"r"
(
a
.
z
),
"r"
(
a
.
w
),
"r"
(
b
.
x
),
"r"
(
b
.
y
),
"r"
(
c
.
x
),
"r"
(
c
.
y
));
#else
asm
volatile
(
"{"
asm
volatile
(
"{"
".reg .b32 tmp0, tmp1;"
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
"{tmp0, tmp1},"
...
...
@@ -68,64 +58,43 @@ static uint2 mma_m16n8k16_f16f16f16f16(uint4 a, uint2 b, uint2 c) {
"{%7},"
"{tmp0, tmp1};"
"}
\n
"
:
"=r"
(
d
.
x
),
"=r"
(
d
.
y
)
:
"r"
(
a
.
x
),
"r"
(
a
.
y
),
"r"
(
a
.
z
),
"r"
(
a
.
w
),
"r"
(
b
.
x
),
"r"
(
b
.
y
),
"r"
(
c
.
x
),
"r"
(
c
.
y
)
);
:
"=r"
(
d
.
x
),
"=r"
(
d
.
y
)
:
"r"
(
a
.
x
),
"r"
(
a
.
y
),
"r"
(
a
.
z
),
"r"
(
a
.
w
),
"r"
(
b
.
x
),
"r"
(
b
.
y
),
"r"
(
c
.
x
),
"r"
(
c
.
y
));
#endif
return
d
;
}
template
<
bool
is_bf16
>
__device__
__forceinline__
static
uint4
mma_m16n8k16_f32f16f16f32
(
uint4
a
,
uint2
b
,
uint4
c
)
=
delete
;
__device__
__forceinline__
static
uint4
mma_m16n8k16_f32f16f16f32
(
uint4
a
,
uint2
b
,
uint4
c
)
=
delete
;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
template
<
>
__device__
__forceinline__
uint4
mma_m16n8k16_f32f16f16f32
<
true
>
(
uint4
a
,
uint2
b
,
uint4
c
)
{
__device__
__forceinline__
uint4
mma_m16n8k16_f32f16f16f32
<
true
>
(
uint4
a
,
uint2
b
,
uint4
c
)
{
uint4
d
;
asm
volatile
(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
asm
volatile
(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};
\n
"
:
"=r"
(
d
.
x
),
"=r"
(
d
.
y
),
"=r"
(
d
.
z
),
"=r"
(
d
.
w
)
:
"r"
(
a
.
x
),
"r"
(
a
.
y
),
"r"
(
a
.
z
),
"r"
(
a
.
w
),
"r"
(
b
.
x
),
"r"
(
b
.
y
),
"r"
(
c
.
x
),
"r"
(
c
.
y
),
"r"
(
c
.
z
),
"r"
(
c
.
w
)
);
:
"=r"
(
d
.
x
),
"=r"
(
d
.
y
),
"=r"
(
d
.
z
),
"=r"
(
d
.
w
)
:
"r"
(
a
.
x
),
"r"
(
a
.
y
),
"r"
(
a
.
z
),
"r"
(
a
.
w
),
"r"
(
b
.
x
),
"r"
(
b
.
y
),
"r"
(
c
.
x
),
"r"
(
c
.
y
),
"r"
(
c
.
z
),
"r"
(
c
.
w
));
return
d
;
}
#endif
template
<
>
__device__
__forceinline__
uint4
mma_m16n8k16_f32f16f16f32
<
false
>
(
uint4
a
,
uint2
b
,
uint4
c
)
{
__device__
__forceinline__
uint4
mma_m16n8k16_f32f16f16f32
<
false
>
(
uint4
a
,
uint2
b
,
uint4
c
)
{
uint4
d
;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm
volatile
(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
asm
volatile
(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};
\n
"
:
"=r"
(
d
.
x
),
"=r"
(
d
.
y
),
"=r"
(
d
.
z
),
"=r"
(
d
.
w
)
:
"r"
(
a
.
x
),
"r"
(
a
.
y
),
"r"
(
a
.
z
),
"r"
(
a
.
w
),
"r"
(
b
.
x
),
"r"
(
b
.
y
),
"r"
(
c
.
x
),
"r"
(
c
.
y
),
"r"
(
c
.
z
),
"r"
(
c
.
w
)
);
:
"=r"
(
d
.
x
),
"=r"
(
d
.
y
),
"=r"
(
d
.
z
),
"=r"
(
d
.
w
)
:
"r"
(
a
.
x
),
"r"
(
a
.
y
),
"r"
(
a
.
z
),
"r"
(
a
.
w
),
"r"
(
b
.
x
),
"r"
(
b
.
y
),
"r"
(
c
.
x
),
"r"
(
c
.
y
),
"r"
(
c
.
z
),
"r"
(
c
.
w
));
#else
asm
volatile
(
"{"
asm
volatile
(
"{"
".reg .b32 tmp0, tmp1, tmp2, tmp3;"
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
"{tmp0, tmp1, tmp2, tmp3},"
...
...
@@ -138,24 +107,17 @@ uint4 mma_m16n8k16_f32f16f16f32<false>(uint4 a, uint2 b, uint4 c) {
"{%9},"
"{tmp0, tmp1, tmp2, tmp3};"
"}
\n
"
:
"=r"
(
d
.
x
),
"=r"
(
d
.
y
),
"=r"
(
d
.
z
),
"=r"
(
d
.
w
)
:
"r"
(
a
.
x
),
"r"
(
a
.
y
),
"r"
(
a
.
z
),
"r"
(
a
.
w
),
"r"
(
b
.
x
),
"r"
(
b
.
y
),
"r"
(
c
.
x
),
"r"
(
c
.
y
),
"r"
(
c
.
z
),
"r"
(
c
.
w
)
);
:
"=r"
(
d
.
x
),
"=r"
(
d
.
y
),
"=r"
(
d
.
z
),
"=r"
(
d
.
w
)
:
"r"
(
a
.
x
),
"r"
(
a
.
y
),
"r"
(
a
.
z
),
"r"
(
a
.
w
),
"r"
(
b
.
x
),
"r"
(
b
.
y
),
"r"
(
c
.
x
),
"r"
(
c
.
y
),
"r"
(
c
.
z
),
"r"
(
c
.
w
));
#endif
return
d
;
}
template
<
typename
AType
,
typename
BType
>
__device__
__forceinline__
static
uint4
mma_m16n8kx_s32common
(
uint4
a
,
uint2
b
,
uint4
c
)
=
delete
;
__device__
__forceinline__
static
uint4
mma_m16n8kx_s32common
(
uint4
a
,
uint2
b
,
uint4
c
)
=
delete
;
template
<
>
__device__
__forceinline__
uint4
mma_m16n8kx_s32common
<
mma_helper
::
s4
,
mma_helper
::
s4
>
(
uint4
a
,
uint2
b
,
uint4
c
)
{
__device__
__forceinline__
uint4
mma_m16n8kx_s32common
<
mma_helper
::
s4
,
mma_helper
::
s4
>
(
uint4
a
,
uint2
b
,
uint4
c
)
{
uint4
d
;
static
constexpr
int
K
=
64
;
...
...
@@ -166,17 +128,10 @@ uint4 mma_m16n8kx_s32common<mma_helper::s4, mma_helper::s4>(uint4 a, uint2 b, ui
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};
\n
"
:
"=r"
(
d
.
x
),
"=r"
(
d
.
y
),
"=r"
(
d
.
z
),
"=r"
(
d
.
w
)
:
"r"
(
a
.
x
),
"r"
(
a
.
y
),
"r"
(
a
.
z
),
"r"
(
a
.
w
),
"r"
(
b
.
x
),
"r"
(
b
.
y
),
"r"
(
c
.
x
),
"r"
(
c
.
y
),
"r"
(
c
.
z
),
"r"
(
c
.
w
),
"n"
(
K
)
);
:
"=r"
(
d
.
x
),
"=r"
(
d
.
y
),
"=r"
(
d
.
z
),
"=r"
(
d
.
w
)
:
"r"
(
a
.
x
),
"r"
(
a
.
y
),
"r"
(
a
.
z
),
"r"
(
a
.
w
),
"r"
(
b
.
x
),
"r"
(
b
.
y
),
"r"
(
c
.
x
),
"r"
(
c
.
y
),
"r"
(
c
.
z
),
"r"
(
c
.
w
),
"n"
(
K
));
#else
asm
volatile
(
"{"
asm
volatile
(
"{"
".reg .b32 tmp0, tmp1, tmp2, tmp3;"
"mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
"{tmp0, tmp1},"
...
...
@@ -199,21 +154,24 @@ uint4 mma_m16n8kx_s32common<mma_helper::s4, mma_helper::s4>(uint4 a, uint2 b, ui
"{%9},"
"{tmp2, tmp3};
\n
"
"}
\n
"
:
"=r"
(
d
.
x
),
"=r"
(
d
.
y
),
"=r"
(
d
.
z
),
"=r"
(
d
.
w
)
:
"r"
(
a
.
x
),
"r"
(
a
.
y
),
"r"
(
a
.
z
),
"r"
(
a
.
w
),
"r"
(
b
.
x
),
"r"
(
b
.
y
),
"r"
(
c
.
x
),
"r"
(
c
.
y
),
"r"
(
c
.
z
),
"r"
(
c
.
w
),
"n"
(
K
/
2
)
);
:
"=r"
(
d
.
x
),
"=r"
(
d
.
y
),
"=r"
(
d
.
z
),
"=r"
(
d
.
w
)
:
"r"
(
a
.
x
),
"r"
(
a
.
y
),
"r"
(
a
.
z
),
"r"
(
a
.
w
),
"r"
(
b
.
x
),
"r"
(
b
.
y
),
"r"
(
c
.
x
),
"r"
(
c
.
y
),
"r"
(
c
.
z
),
"r"
(
c
.
w
),
"n"
(
K
/
2
));
#endif
return
d
;
}
template
<
>
__device__
__forceinline__
uint4
mma_m16n8kx_s32common
<
mma_helper
::
u4
,
mma_helper
::
s4
>
(
uint4
a
,
uint2
b
,
uint4
c
)
{
__device__
__forceinline__
uint4
mma_m16n8kx_s32common
<
mma_helper
::
u4
,
mma_helper
::
s4
>
(
uint4
a
,
uint2
b
,
uint4
c
)
{
uint4
d
;
static
constexpr
int
K
=
64
;
...
...
@@ -224,17 +182,10 @@ uint4 mma_m16n8kx_s32common<mma_helper::u4, mma_helper::s4>(uint4 a, uint2 b, ui
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};
\n
"
:
"=r"
(
d
.
x
),
"=r"
(
d
.
y
),
"=r"
(
d
.
z
),
"=r"
(
d
.
w
)
:
"r"
(
a
.
x
),
"r"
(
a
.
y
),
"r"
(
a
.
z
),
"r"
(
a
.
w
),
"r"
(
b
.
x
),
"r"
(
b
.
y
),
"r"
(
c
.
x
),
"r"
(
c
.
y
),
"r"
(
c
.
z
),
"r"
(
c
.
w
),
"n"
(
K
)
);
:
"=r"
(
d
.
x
),
"=r"
(
d
.
y
),
"=r"
(
d
.
z
),
"=r"
(
d
.
w
)
:
"r"
(
a
.
x
),
"r"
(
a
.
y
),
"r"
(
a
.
z
),
"r"
(
a
.
w
),
"r"
(
b
.
x
),
"r"
(
b
.
y
),
"r"
(
c
.
x
),
"r"
(
c
.
y
),
"r"
(
c
.
z
),
"r"
(
c
.
w
),
"n"
(
K
));
#else
asm
volatile
(
"{"
asm
volatile
(
"{"
".reg .b32 tmp0, tmp1, tmp2, tmp3;"
"mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
"{tmp0, tmp1},"
...
...
@@ -257,17 +208,20 @@ uint4 mma_m16n8kx_s32common<mma_helper::u4, mma_helper::s4>(uint4 a, uint2 b, ui
"{%9},"
"{tmp2, tmp3};
\n
"
"}
\n
"
:
"=r"
(
d
.
x
),
"=r"
(
d
.
y
),
"=r"
(
d
.
z
),
"=r"
(
d
.
w
)
:
"r"
(
a
.
x
),
"r"
(
a
.
y
),
"r"
(
a
.
z
),
"r"
(
a
.
w
),
"r"
(
b
.
x
),
"r"
(
b
.
y
),
"r"
(
c
.
x
),
"r"
(
c
.
y
),
"r"
(
c
.
z
),
"r"
(
c
.
w
),
"n"
(
K
/
2
)
);
:
"=r"
(
d
.
x
),
"=r"
(
d
.
y
),
"=r"
(
d
.
z
),
"=r"
(
d
.
w
)
:
"r"
(
a
.
x
),
"r"
(
a
.
y
),
"r"
(
a
.
z
),
"r"
(
a
.
w
),
"r"
(
b
.
x
),
"r"
(
b
.
y
),
"r"
(
c
.
x
),
"r"
(
c
.
y
),
"r"
(
c
.
z
),
"r"
(
c
.
w
),
"n"
(
K
/
2
));
#endif
return
d
;
}
};
// namespace nunchaku::kernels
src/kernels/zgemm/zgemm.h
View file @
37a27712
...
...
@@ -5,8 +5,7 @@
namespace
nunchaku
::
kernels
{
void
gemm_w4a4
(
Tensor
act
,
// packed act [M, K / 2]
void
gemm_w4a4
(
Tensor
act
,
// packed act [M, K / 2]
Tensor
wgt
,
// packed act [N, K / 2]
Tensor
out
,
// linear [M, N]
Tensor
qout
,
// packed act [M, N / 2]
...
...
@@ -24,7 +23,7 @@ void gemm_w4a4(
Tensor
bias
,
// packed ws [N]
Tensor
smooth_factor
,
// packed ws [N], for quantization of the next layer
Tensor
out_vk
,
// linear [B, num_heads, head_dim + 1, head_dim]
Tensor
out_linearattn
,
// linear [B, (M), N / 3]
Tensor
out_linearattn
,
// linear [B, (M), N / 3]
bool
act_unsigned
,
std
::
vector
<
float
>
lora_scales
,
// [R / 16]
bool
fuse_silu
,
...
...
@@ -34,11 +33,17 @@ void gemm_w4a4(
Tensor
out_q
,
// packed attention [B, H, M, D]
Tensor
out_k
,
// packed attention [B, H, M, D]
Tensor
out_v
,
// packed attention [B, H, M, D]
int
attn_tokens
);
int
attn_tokens
);
void
linearattn_vk_mul_q
(
Tensor
q
,
Tensor
vk
);
void
quantize_w4a4_act_fuse_lora
(
Tensor
input
,
Tensor
output
,
Tensor
oscales
,
Tensor
lora_down
,
Tensor
lora_act_out
,
Tensor
smooth
=
{},
bool
fuse_glu
=
false
,
bool
fp4
=
false
);
void
quantize_w4a4_act_fuse_lora
(
Tensor
input
,
Tensor
output
,
Tensor
oscales
,
Tensor
lora_down
,
Tensor
lora_act_out
,
Tensor
smooth
=
{},
bool
fuse_glu
=
false
,
bool
fp4
=
false
);
void
quantize_w4a4_act
(
Tensor
input
,
Tensor
output
,
Tensor
oscales
);
void
quantize_w4a4_wgt
(
Tensor
input
,
Tensor
output
,
Tensor
oscales
);
...
...
@@ -48,7 +53,7 @@ void gemm_w8a8(Tensor act, // [M, K]
Tensor
ascales
,
// [1, M]
Tensor
wscales
,
// [1, N]
Tensor
bias
// packed ws [N]
);
);
void
quantize_w8a8_act
(
Tensor
input
,
Tensor
output
,
Tensor
oscales
,
bool
fuse_glu
);
...
...
@@ -61,13 +66,11 @@ void quantize_w8a8_act(Tensor input, Tensor output, Tensor oscales, bool fuse_gl
// Tensor wscales // [1, N]
// );
void
attention_fp16
(
Tensor
q
,
// packed [Batch, Head, TokensQ, HEAD_DIM]
void
attention_fp16
(
Tensor
q
,
// packed [Batch, Head, TokensQ, HEAD_DIM]
Tensor
k
,
// packed [Batch, Head, TokensKV, HEAD_DIM]
Tensor
v
,
// packed [Batch, Head, TokensKV, HEAD_DIM]
Tensor
o
,
// linear [Batch, TokensQ, Head * HEAD_DIM]
float
scale
);
float
scale
);
// EXPERIMENTAL, for sm_75
void
set_faster_i2f_mode
(
std
::
string
mode
);
...
...
src/layernorm.cpp
View file @
37a27712
#include "layernorm.h"
#include "kernels/layernorm_kernels.h"
LayerNorm
::
LayerNorm
(
int
hidden_size
,
float
eps
,
bool
elementwise_affine
,
Tensor
::
ScalarType
dtype
,
Device
device
)
:
hidden_size
(
hidden_size
),
eps
(
eps
)
{
LayerNorm
::
LayerNorm
(
int
hidden_size
,
float
eps
,
bool
elementwise_affine
,
Tensor
::
ScalarType
dtype
,
Device
device
)
:
hidden_size
(
hidden_size
),
eps
(
eps
)
{
if
(
elementwise_affine
)
{
weight
=
Tensor
::
allocate
({
hidden_size
},
dtype
,
device
);
bias
=
Tensor
::
allocate
({
hidden_size
},
dtype
,
device
);
}
registerParams
(
weight
,
"weight"
)
(
bias
,
"bias"
)
;
registerParams
(
weight
,
"weight"
)(
bias
,
"bias"
);
}
Tensor
LayerNorm
::
forward
(
Tensor
x
)
{
...
...
@@ -27,10 +23,23 @@ Tensor RMSNorm::forward(Tensor x) {
return
out
;
}
void
RMSNormGeneral
::
forward_with_act_sum
(
Tensor
x
,
Tensor
quantized_hidden_states_buffer
,
Tensor
quantized_scale_buffer
,
Tensor
quantized_sum_buffer
)
{
rms_norm_general_fuse_sum
(
quantized_hidden_states_buffer
,
x
,
this
->
weight
,
quantized_sum_buffer
,
quantized_scale_buffer
,
variance_epsilon
,
use_per_token_quant
);
void
RMSNormGeneral
::
forward_with_act_sum
(
Tensor
x
,
Tensor
quantized_hidden_states_buffer
,
Tensor
quantized_scale_buffer
,
Tensor
quantized_sum_buffer
)
{
rms_norm_general_fuse_sum
(
quantized_hidden_states_buffer
,
x
,
this
->
weight
,
quantized_sum_buffer
,
quantized_scale_buffer
,
variance_epsilon
,
use_per_token_quant
);
}
void
RMSNormGeneral
::
forward_wo_act_sum
(
Tensor
x
,
Tensor
quantized_hidden_states_buffer
,
Tensor
quantized_scale_buffer
,
Tensor
quantized_sum_buffer
)
{
rms_norm_general
(
quantized_hidden_states_buffer
,
x
,
this
->
weight
,
quantized_scale_buffer
,
variance_epsilon
,
use_per_token_quant
);
void
RMSNormGeneral
::
forward_wo_act_sum
(
Tensor
x
,
Tensor
quantized_hidden_states_buffer
,
Tensor
quantized_scale_buffer
,
Tensor
quantized_sum_buffer
)
{
rms_norm_general
(
quantized_hidden_states_buffer
,
x
,
this
->
weight
,
quantized_scale_buffer
,
variance_epsilon
,
use_per_token_quant
);
}
src/layernorm.h
View file @
37a27712
...
...
@@ -20,9 +20,8 @@ private:
class
RMSNorm
:
public
Module
{
public:
RMSNorm
(
int
hidden_size
,
float
eps
,
bool
use_quant
,
Tensor
::
ScalarType
dtype
,
Device
device
)
:
use_quant
(
use_quant
),
variance_epsilon
(
eps
)
{
RMSNorm
(
int
hidden_size
,
float
eps
,
bool
use_quant
,
Tensor
::
ScalarType
dtype
,
Device
device
)
:
use_quant
(
use_quant
),
variance_epsilon
(
eps
)
{
weight
=
Tensor
::
allocate
({
hidden_size
},
dtype
,
device
);
registerParams
(
weight
,
"weight"
);
}
...
...
@@ -36,13 +35,16 @@ public:
class
RMSNormGeneral
{
friend
class
LlamaDecoderLayer
;
public:
RMSNormGeneral
(
int
hidden_size
,
bool
act_sum
,
float
eps
,
bool
use_per_token_quant
,
Device
device
)
:
act_sum
(
act_sum
),
use_per_token_quant
(
use_per_token_quant
),
variance_epsilon
(
eps
)
{
:
act_sum
(
act_sum
),
use_per_token_quant
(
use_per_token_quant
),
variance_epsilon
(
eps
)
{
this
->
weight
=
Tensor
::
ones
({
hidden_size
},
Tensor
::
FP32
,
device
);
}
void
forward
(
Tensor
x
,
Tensor
quantized_hidden_states_buffer
,
Tensor
quantized_scale_buffer
,
Tensor
quantized_sum_buffer
)
{
void
forward
(
Tensor
x
,
Tensor
quantized_hidden_states_buffer
,
Tensor
quantized_scale_buffer
,
Tensor
quantized_sum_buffer
)
{
if
(
act_sum
)
{
forward_with_act_sum
(
x
,
quantized_hidden_states_buffer
,
quantized_scale_buffer
,
quantized_sum_buffer
);
}
else
{
...
...
@@ -51,8 +53,14 @@ public:
}
private:
void
forward_with_act_sum
(
Tensor
x
,
Tensor
quantized_hidden_states_buffer
,
Tensor
quantized_scale_buffer
,
Tensor
quantized_sum_buffer
);
void
forward_wo_act_sum
(
Tensor
x
,
Tensor
quantized_hidden_states_buffer
,
Tensor
quantized_scale_buffer
,
Tensor
quantized_sum_buffer
);
void
forward_with_act_sum
(
Tensor
x
,
Tensor
quantized_hidden_states_buffer
,
Tensor
quantized_scale_buffer
,
Tensor
quantized_sum_buffer
);
void
forward_wo_act_sum
(
Tensor
x
,
Tensor
quantized_hidden_states_buffer
,
Tensor
quantized_scale_buffer
,
Tensor
quantized_sum_buffer
);
private:
const
bool
act_sum
;
...
...
src/pytorch_compat.h
View file @
37a27712
...
...
@@ -4,103 +4,106 @@
#include "Tensor.h"
namespace
pytorch_compat
{
inline
void
TORCH_CHECK
(
bool
cond
,
const
std
::
string
&
msg
=
""
)
{
assert
(
cond
);
}
inline
void
TORCH_CHECK
(
bool
cond
,
const
std
::
string
&
msg
=
""
)
{
assert
(
cond
);
}
template
<
typename
T
>
inline
void
C10_CUDA_CHECK
(
T
ret
)
{
template
<
typename
T
>
inline
void
C10_CUDA_CHECK
(
T
ret
)
{
return
checkCUDA
(
ret
);
}
}
namespace
at
{
using
::
Tensor
;
namespace
at
{
using
::
Tensor
;
constexpr
auto
kFloat32
=
Tensor
::
FP32
;
constexpr
auto
kFloat
=
Tensor
::
FP32
;
constexpr
auto
kFloat16
=
Tensor
::
FP16
;
constexpr
auto
kBFloat16
=
Tensor
::
BF16
;
constexpr
auto
kInt32
=
Tensor
::
INT32
;
constexpr
auto
kInt64
=
Tensor
::
INT64
;
constexpr
auto
kFloat32
=
Tensor
::
FP32
;
constexpr
auto
kFloat
=
Tensor
::
FP32
;
constexpr
auto
kFloat16
=
Tensor
::
FP16
;
constexpr
auto
kBFloat16
=
Tensor
::
BF16
;
constexpr
auto
kInt32
=
Tensor
::
INT32
;
constexpr
auto
kInt64
=
Tensor
::
INT64
;
struct
Generator
{
Generator
()
{
throw
std
::
runtime_error
(
"Not implemented"
);
}
struct
Generator
{
Generator
()
{
throw
std
::
runtime_error
(
"Not implemented"
);
}
std
::
mutex
mutex_
;
};
};
namespace
cuda
{
using
::
getCurrentDeviceProperties
;
namespace
cuda
{
using
::
getCurrentDeviceProperties
;
struct
StreamWrapper
{
struct
StreamWrapper
{
cudaStream_t
st
;
cudaStream_t
stream
()
const
{
return
st
;
}
};
inline
StreamWrapper
getCurrentCUDAStream
()
{
return
StreamWrapper
(
::
getCurrentCUDAStream
());
cudaStream_t
stream
()
const
{
return
st
;
}
};
inline
StreamWrapper
getCurrentCUDAStream
()
{
return
StreamWrapper
(
::
getCurrentCUDAStream
());
}
struct
CUDAGuard
{
struct
CUDAGuard
{
int
dev
;
};
};
namespace
detail
{
inline
Generator
getDefaultCUDAGenerator
()
{
namespace
detail
{
inline
Generator
getDefaultCUDAGenerator
()
{
return
Generator
();
}
}
}
}
}
// namespace detail
}
// namespace cuda
using
CUDAGeneratorImpl
=
Generator
;
using
CUDAGeneratorImpl
=
Generator
;
template
<
typename
T
>
std
::
unique_ptr
<
Generator
>
get_generator_or_default
(
std
::
optional
<
Generator
>
gen
,
T
gen2
)
{
template
<
typename
T
>
std
::
unique_ptr
<
Generator
>
get_generator_or_default
(
std
::
optional
<
Generator
>
gen
,
T
gen2
)
{
throw
std
::
runtime_error
(
"Not implemented"
);
}
}
}
}
// namespace at
namespace
torch
{
using
at
::
kFloat32
;
using
at
::
kFloat
;
using
at
::
kFloat16
;
using
at
::
kBFloat16
;
using
at
::
kInt32
;
using
at
::
kInt64
;
constexpr
Device
kCUDA
=
Device
::
cuda
();
namespace
torch
{
using
at
::
kFloat32
;
using
at
::
kFloat
;
using
at
::
kFloat16
;
using
at
::
kBFloat16
;
using
at
::
kInt32
;
using
at
::
kInt64
;
constexpr
Device
kCUDA
=
Device
::
cuda
();
using
IntArrayRef
=
std
::
vector
<
int
>
;
using
TensorOptions
=
Tensor
::
TensorOptions
;
using
IntArrayRef
=
std
::
vector
<
int
>
;
using
TensorOptions
=
Tensor
::
TensorOptions
;
inline
Tensor
empty_like
(
const
Tensor
&
tensor
)
{
inline
Tensor
empty_like
(
const
Tensor
&
tensor
)
{
return
Tensor
::
empty_like
(
tensor
);
}
inline
Tensor
empty
(
TensorShape
shape
,
Tensor
::
TensorOptions
options
)
{
}
inline
Tensor
empty
(
TensorShape
shape
,
Tensor
::
TensorOptions
options
)
{
return
Tensor
::
empty
(
shape
,
options
.
dtype
(),
options
.
device
());
}
inline
Tensor
zeros
(
TensorShape
shape
,
Tensor
::
TensorOptions
options
)
{
}
inline
Tensor
zeros
(
TensorShape
shape
,
Tensor
::
TensorOptions
options
)
{
return
Tensor
::
empty
(
shape
,
options
.
dtype
(),
options
.
device
()).
zero_
();
}
}
namespace
nn
{
namespace
functional
{
using
PadFuncOptions
=
std
::
vector
<
int
>
;
inline
Tensor
pad
(
Tensor
x
,
PadFuncOptions
options
)
{
namespace
nn
{
namespace
functional
{
using
PadFuncOptions
=
std
::
vector
<
int
>
;
inline
Tensor
pad
(
Tensor
x
,
PadFuncOptions
options
)
{
throw
std
::
runtime_error
(
"Not implemented"
);
}
}
}
}
}
// namespace functional
}
// namespace nn
namespace
indexing
{
constexpr
int
None
=
0
;
struct
Slice
{
namespace
indexing
{
constexpr
int
None
=
0
;
struct
Slice
{
int
a
;
int
b
;
};
}
}
namespace
c10
{
using
std
::
optional
;
}
};
}
// namespace indexing
}
// namespace torch
namespace
c10
{
using
std
::
optional
;
}
}
// namespace pytorch_compat
tests/README.md
View file @
37a27712
tests/data/__init__.py
View file @
37a27712
...
...
@@ -3,7 +3,6 @@ import random
import
datasets
import
yaml
from
huggingface_hub
import
snapshot_download
from
nunchaku.utils
import
fetch_or_download
...
...
tests/flux/test_flux_cache.py
View file @
37a27712
import
pytest
from
nunchaku.utils
import
get_precision
,
is_turing
from
.utils
import
run_test
...
...
@@ -8,7 +9,7 @@ from .utils import run_test
@
pytest
.
mark
.
parametrize
(
"cache_threshold,height,width,num_inference_steps,lora_name,lora_strength,expected_lpips"
,
[
(
0.12
,
1024
,
1024
,
30
,
None
,
1
,
0.212
if
get_precision
()
==
"int4"
else
0.1
44
),
(
0.12
,
1024
,
1024
,
30
,
None
,
1
,
0.212
if
get_precision
()
==
"int4"
else
0.1
61
),
],
)
def
test_flux_dev_cache
(
...
...
tests/flux/test_flux_dev.py
View file @
37a27712
import
pytest
from
nunchaku.utils
import
get_precision
,
is_turing
from
.utils
import
run_test
...
...
@@ -9,7 +10,7 @@ from .utils import run_test
"height,width,num_inference_steps,attention_impl,cpu_offload,expected_lpips"
,
[
(
1024
,
1024
,
50
,
"flashattn2"
,
False
,
0.139
if
get_precision
()
==
"int4"
else
0.146
),
(
2048
,
512
,
25
,
"nunchaku-fp16"
,
False
,
0.168
if
get_precision
()
==
"int4"
else
0.1
33
),
(
2048
,
512
,
25
,
"nunchaku-fp16"
,
False
,
0.168
if
get_precision
()
==
"int4"
else
0.1
56
),
],
)
def
test_flux_dev
(
...
...
Prev
1
…
5
6
7
8
9
10
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment