Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
fengzch-das
nunchaku
Commits
37a27712
Unverified
Commit
37a27712
authored
May 01, 2025
by
Muyang Li
Committed by
GitHub
May 01, 2025
Browse files
Merge pull request #340 from mit-han-lab/dev
feat: support PuLID, Double FBCache and TeaCache; better linter
parents
c1d6fc84
760ab022
Changes
192
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
981 additions
and
1027 deletions
+981
-1027
src/kernels/zgemm/gemm_w4a4_launch_bf16_fp4.cu
src/kernels/zgemm/gemm_w4a4_launch_bf16_fp4.cu
+2
-2
src/kernels/zgemm/gemm_w4a4_launch_bf16_int4.cu
src/kernels/zgemm/gemm_w4a4_launch_bf16_int4.cu
+2
-2
src/kernels/zgemm/gemm_w4a4_launch_fp16_fp4.cu
src/kernels/zgemm/gemm_w4a4_launch_fp16_fp4.cu
+2
-2
src/kernels/zgemm/gemm_w4a4_launch_fp16_int4.cu
src/kernels/zgemm/gemm_w4a4_launch_fp16_int4.cu
+2
-2
src/kernels/zgemm/gemm_w4a4_launch_fp16_int4_fasteri2f.cu
src/kernels/zgemm/gemm_w4a4_launch_fp16_int4_fasteri2f.cu
+2
-2
src/kernels/zgemm/gemm_w4a4_launch_impl.cuh
src/kernels/zgemm/gemm_w4a4_launch_impl.cuh
+201
-187
src/kernels/zgemm/gemm_w4a4_test.cu
src/kernels/zgemm/gemm_w4a4_test.cu
+31
-33
src/kernels/zgemm/gemm_w8a8.cu
src/kernels/zgemm/gemm_w8a8.cu
+39
-42
src/kernels/zgemm/gemm_w8a8.cuh
src/kernels/zgemm/gemm_w8a8.cuh
+148
-162
src/kernels/zgemm/lora.cuh
src/kernels/zgemm/lora.cuh
+70
-75
src/kernels/zgemm/mma.cuh
src/kernels/zgemm/mma.cuh
+137
-151
src/kernels/zgemm/mma_earlycuda.cuh
src/kernels/zgemm/mma_earlycuda.cuh
+156
-202
src/kernels/zgemm/zgemm.h
src/kernels/zgemm/zgemm.h
+50
-47
src/layernorm.cpp
src/layernorm.cpp
+21
-12
src/layernorm.h
src/layernorm.h
+18
-10
src/pytorch_compat.h
src/pytorch_compat.h
+94
-91
tests/README.md
tests/README.md
+2
-2
tests/data/__init__.py
tests/data/__init__.py
+0
-1
tests/flux/test_flux_cache.py
tests/flux/test_flux_cache.py
+2
-1
tests/flux/test_flux_dev.py
tests/flux/test_flux_dev.py
+2
-1
No files found.
src/kernels/zgemm/gemm_w4a4_launch_bf16_fp4.cu
View file @
37a27712
#include "gemm_w4a4_launch_impl.cuh"
#include "gemm_w4a4_launch_impl.cuh"
namespace
nunchaku
::
kernels
{
namespace
nunchaku
::
kernels
{
template
class
GEMM_W4A4_Launch
<
GEMMConfig_W4A4_BF16
,
true
>;
template
class
GEMM_W4A4_Launch
<
GEMMConfig_W4A4_BF16
,
true
>;
};
};
\ No newline at end of file
src/kernels/zgemm/gemm_w4a4_launch_bf16_int4.cu
View file @
37a27712
#include "gemm_w4a4_launch_impl.cuh"
#include "gemm_w4a4_launch_impl.cuh"
namespace
nunchaku
::
kernels
{
namespace
nunchaku
::
kernels
{
template
class
GEMM_W4A4_Launch
<
GEMMConfig_W4A4_BF16
,
false
>;
template
class
GEMM_W4A4_Launch
<
GEMMConfig_W4A4_BF16
,
false
>;
};
};
\ No newline at end of file
src/kernels/zgemm/gemm_w4a4_launch_fp16_fp4.cu
View file @
37a27712
#include "gemm_w4a4_launch_impl.cuh"
#include "gemm_w4a4_launch_impl.cuh"
namespace
nunchaku
::
kernels
{
namespace
nunchaku
::
kernels
{
template
class
GEMM_W4A4_Launch
<
GEMMConfig_W4A4_FP16
,
true
>;
template
class
GEMM_W4A4_Launch
<
GEMMConfig_W4A4_FP16
,
true
>;
};
};
\ No newline at end of file
src/kernels/zgemm/gemm_w4a4_launch_fp16_int4.cu
View file @
37a27712
#include "gemm_w4a4_launch_impl.cuh"
#include "gemm_w4a4_launch_impl.cuh"
namespace
nunchaku
::
kernels
{
namespace
nunchaku
::
kernels
{
template
class
GEMM_W4A4_Launch
<
GEMMConfig_W4A4_FP16
,
false
>;
template
class
GEMM_W4A4_Launch
<
GEMMConfig_W4A4_FP16
,
false
>;
};
};
\ No newline at end of file
src/kernels/zgemm/gemm_w4a4_launch_fp16_int4_fasteri2f.cu
View file @
37a27712
#include "gemm_w4a4_launch_impl.cuh"
#include "gemm_w4a4_launch_impl.cuh"
namespace
nunchaku
::
kernels
{
namespace
nunchaku
::
kernels
{
template
class
GEMM_W4A4_Launch
<
GEMMConfig_W4A4_FP16_FasterI2F
,
false
>;
template
class
GEMM_W4A4_Launch
<
GEMMConfig_W4A4_FP16_FasterI2F
,
false
>;
};
};
\ No newline at end of file
src/kernels/zgemm/gemm_w4a4_launch_impl.cuh
View file @
37a27712
...
@@ -9,36 +9,35 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::gemm_w4a4(
...
@@ -9,36 +9,35 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::gemm_w4a4(
template
<
>
template
<
>
void
GEMM_W4A4_Launch
<
GEMMConfig_W4A4_FP16
,
false
>::
gemm_w4a4
(
void
GEMM_W4A4_Launch
<
GEMMConfig_W4A4_FP16
,
false
>::
gemm_w4a4
(
#endif
#endif
Tensor
act
,
// packed act [M, K / 2]
Tensor
act
,
// packed act [M, K / 2]
Tensor
wgt
,
// packed act [N, K / 2]
Tensor
wgt
,
// packed act [N, K / 2]
Tensor
out
,
// linear [M, N]
Tensor
out
,
// linear [M, N]
Tensor
qout
,
// packed act [M, N / 2]
Tensor
qout
,
// packed act [M, N / 2]
Tensor
ascales
,
// packed as [K / 64, M]
Tensor
ascales
,
// packed as [K / 64, M]
Tensor
wscales
,
// packed ws [K / 64, N]
Tensor
wscales
,
// packed ws [K / 64, N]
Tensor
oscales
,
// packed as [N / 64, M]
Tensor
oscales
,
// packed as [N / 64, M]
Tensor
poolout
,
// linear [M / PoolSize, N]
Tensor
poolout
,
// linear [M / PoolSize, N]
Tensor
lora_act_in
,
// packed lora_act [M, R]
Tensor
lora_act_in
,
// packed lora_act [M, R]
Tensor
lora_up
,
// packed lora_wgt [N, R]
Tensor
lora_up
,
// packed lora_wgt [N, R]
Tensor
lora_down
,
// packed lora_wgt [N, R]
Tensor
lora_down
,
// packed lora_wgt [N, R]
Tensor
lora_act_out
,
// packed lora_act [M, R]
Tensor
lora_act_out
,
// packed lora_act [M, R]
Tensor
norm_q
,
// linear [HEAD_DIM]
Tensor
norm_q
,
// linear [HEAD_DIM]
Tensor
norm_k
,
// linear [HEAD_DIM]
Tensor
norm_k
,
// linear [HEAD_DIM]
Tensor
rotary_emb
,
// linear [M, HEAD_DIM / 2, 2, 2]
Tensor
rotary_emb
,
// linear [M, HEAD_DIM / 2, 2, 2]
Tensor
bias
,
// packed ws [N]
Tensor
bias
,
// packed ws [N]
Tensor
smooth_factor
,
// packed ws [N], for quantization of the next layer
Tensor
smooth_factor
,
// packed ws [N], for quantization of the next layer
Tensor
out_vk
,
// linear [B, num_heads, head_dim + 1, head_dim]
Tensor
out_vk
,
// linear [B, num_heads, head_dim + 1, head_dim]
Tensor
out_linearattn
,
// linear [B, (M), N / 3]
Tensor
out_linearattn
,
// linear [B, (M), N / 3]
bool
act_unsigned
,
bool
act_unsigned
,
std
::
vector
<
float
>
lora_scales
,
// [R / 16]
std
::
vector
<
float
>
lora_scales
,
// [R / 16]
bool
fuse_silu
,
bool
fuse_silu
,
bool
fp4
,
bool
fp4
,
float
alpha
,
float
alpha
,
Tensor
wcscales
,
// packed ws [N]
Tensor
wcscales
,
// packed ws [N]
Tensor
out_q
,
// packed attention [B, H, M, D]
Tensor
out_q
,
// packed attention [B, H, M, D]
Tensor
out_k
,
// packed attention [B, H, M, D]
Tensor
out_k
,
// packed attention [B, H, M, D]
Tensor
out_v
,
// packed attention [B, H, M, D]
Tensor
out_v
,
// packed attention [B, H, M, D]
int
attn_tokens
int
attn_tokens
)
{
)
{
#ifdef __INTELLISENSE__
#ifdef __INTELLISENSE__
static
constexpr
bool
USE_FP4
=
false
;
static
constexpr
bool
USE_FP4
=
false
;
#endif
#endif
...
@@ -89,32 +88,35 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
...
@@ -89,32 +88,35 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
if
constexpr
(
!
USE_FP4
)
{
if
constexpr
(
!
USE_FP4
)
{
dispatchBool
(
act_unsigned
,
[
&
]
<
bool
ACT_UNSIGNED
>
()
{
dispatchBool
(
act_unsigned
,
[
&
]
<
bool
ACT_UNSIGNED
>
()
{
auto
func
=
invoke_kernel
<
typename
GEMM
::
gemm_w4a4_kernel
<
Epilogue
,
ACT_UNSIGNED
>
,
auto
func
=
invoke_kernel
<
typename
GEMM
::
gemm_w4a4_kernel
<
Epilogue
,
ACT_UNSIGNED
>
,
const
packed_act_t
*
,
const
packed_act_t
*
,
const
packed_wgt_t
*
,
const
packed_wgt_t
*
,
const
packed_ascale_t
*
,
const
packed_ascale_t
*
,
const
packed_wscale_t
*
,
const
packed_wscale_t
*
,
int
,
int
,
int
,
int
,
typename
Epilogue
::
Arguments
,
int
,
bool
,
int
,
bool
>
;
typename
Epilogue
::
Arguments
,
bool
,
bool
>
;
if
(
shmem
>=
24
*
1024
)
{
if
(
shmem
>=
24
*
1024
)
{
checkCUDA
(
cudaFuncSetAttribute
(
func
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
shmem
));
checkCUDA
(
cudaFuncSetAttribute
(
func
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
shmem
));
}
}
assert
(
alpha
==
1.0
f
);
assert
(
alpha
==
1.0
f
);
func
<<<
grid
,
GEMM
::
WARP_SIZE
*
GEMM
::
NUM_WARPS
,
shmem
,
getCurrentCUDAStream
()
>>>
(
func
<<<
grid
,
GEMM
::
WARP_SIZE
*
GEMM
::
NUM_WARPS
,
shmem
,
getCurrentCUDAStream
()
>>>
(
act
.
data_ptr
<
packed_act_t
>
(),
act
.
data_ptr
<
packed_act_t
>
(),
wgt
.
data_ptr
<
packed_wgt_t
>
(),
wgt
.
data_ptr
<
packed_wgt_t
>
(),
ascales
.
data_ptr
<
packed_ascale_t
>
(),
ascales
.
data_ptr
<
packed_ascale_t
>
(),
wscales
.
data_ptr
<
packed_wscale_t
>
(),
wscales
.
data_ptr
<
packed_wscale_t
>
(),
M
,
N
,
K
,
M
,
N
,
K
,
args
,
args
,
swapBlockMN
,
swapBlockMN
,
false
false
);
);
checkCUDA
(
cudaGetLastError
());
checkCUDA
(
cudaGetLastError
());
});
});
return
;
return
;
...
@@ -124,16 +126,18 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
...
@@ -124,16 +126,18 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
dispatchBool
(
alpha
!=
1.0
f
,
[
&
]
<
bool
USE_ALPHA
>
()
{
dispatchBool
(
alpha
!=
1.0
f
,
[
&
]
<
bool
USE_ALPHA
>
()
{
assert
(
!
act_unsigned
);
assert
(
!
act_unsigned
);
auto
func
=
invoke_kernel
<
typename
GEMM
::
gemm_w4a4_fp4_kernel
<
Epilogue
,
USE_ALPHA
>
,
auto
func
=
invoke_kernel
<
typename
GEMM
::
gemm_w4a4_fp4_kernel
<
Epilogue
,
USE_ALPHA
>
,
const
packed_act_t
*
,
const
packed_act_t
*
,
const
packed_wgt_t
*
,
const
packed_wgt_t
*
,
const
packed_amscale_t
*
,
const
packed_amscale_t
*
,
const
packed_wmscale_t
*
,
const
packed_wmscale_t
*
,
float
,
float
,
int
,
int
,
int
,
int
,
typename
Epilogue
::
Arguments
,
int
,
bool
,
int
,
bool
>
;
typename
Epilogue
::
Arguments
,
bool
,
bool
>
;
if
(
shmem
>=
24
*
1024
)
{
if
(
shmem
>=
24
*
1024
)
{
checkCUDA
(
cudaFuncSetAttribute
(
func
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
shmem
));
checkCUDA
(
cudaFuncSetAttribute
(
func
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
shmem
));
...
@@ -141,21 +145,22 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
...
@@ -141,21 +145,22 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
assert
(
ascales
.
dtype
()
==
Tensor
::
FP8_E4M3
);
assert
(
ascales
.
dtype
()
==
Tensor
::
FP8_E4M3
);
assert
(
wscales
.
dtype
()
==
Tensor
::
FP8_E4M3
);
assert
(
wscales
.
dtype
()
==
Tensor
::
FP8_E4M3
);
func
<<<
grid
,
GEMM
::
WARP_SIZE
*
GEMM
::
NUM_WARPS
,
shmem
,
getCurrentCUDAStream
()
>>>
(
func
<<<
grid
,
GEMM
::
WARP_SIZE
*
GEMM
::
NUM_WARPS
,
shmem
,
getCurrentCUDAStream
()
>>>
(
act
.
data_ptr
<
packed_act_t
>
(),
act
.
data_ptr
<
packed_act_t
>
(),
wgt
.
data_ptr
<
packed_wgt_t
>
(),
wgt
.
data_ptr
<
packed_wgt_t
>
(),
ascales
.
data_ptr
<
packed_amscale_t
>
(),
ascales
.
data_ptr
<
packed_amscale_t
>
(),
wscales
.
data_ptr
<
packed_wmscale_t
>
(),
wscales
.
data_ptr
<
packed_wmscale_t
>
(),
alpha
,
alpha
,
M
,
N
,
K
,
M
,
N
,
K
,
args
,
args
,
swapBlockMN
,
swapBlockMN
,
false
false
);
);
checkCUDA
(
cudaGetLastError
());
checkCUDA
(
cudaGetLastError
());
});
});
return
;
return
;
}
}
...
@@ -171,35 +176,37 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
...
@@ -171,35 +176,37 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
dispatchBool
(
bias
.
valid
(),
[
&
]
<
bool
USE_BIAS
>
()
{
dispatchBool
(
bias
.
valid
(),
[
&
]
<
bool
USE_BIAS
>
()
{
dispatchBool
(
wcscales
.
valid
(),
[
&
]
<
bool
USE_SCALE
>
()
{
dispatchBool
(
wcscales
.
valid
(),
[
&
]
<
bool
USE_SCALE
>
()
{
using
EpilogueBias
=
typename
GEMM
::
EpilogueBias
<
USE_BIAS
,
USE_SCALE
>
;
using
EpilogueBias
=
typename
GEMM
::
EpilogueBias
<
USE_BIAS
,
USE_SCALE
>
;
// append EpilgoueNop to workaround mismatched memory layout of std::tuple between device and host code on Windows
// append EpilgoueNop to workaround mismatched memory layout of std::tuple between device and host code
// on Windows
// ** sizeof(std::tuple<std::tuple<int>>) == 8 on device **
// ** sizeof(std::tuple<std::tuple<int>>) == 8 on device **
using
Epilogue
=
typename
GEMM
::
EpilogueCombination
<
EpilogueBias
,
NextEpilogue
,
typename
GEMM
::
EpilogueNop
>
;
using
Epilogue
=
return
launch
.
template
operator
()
<
Epilogue
>({
typename
GEMM
::
EpilogueCombination
<
EpilogueBias
,
NextEpilogue
,
typename
GEMM
::
Epilogue
Nop
>
;
typename
EpilogueBias
::
Arguments
{
return
launch
.
template
operator
()
<
Epilogue
>(
.
bias
=
USE_BIAS
?
bias
.
data_ptr
<
packed_wscale_t
>
()
:
nullptr
,
{
typename
EpilogueBias
::
Arguments
{
.
scale
=
USE_SCALE
?
wcscale
s
.
data_ptr
<
packed_wscale_t
>
()
:
nullptr
,
.
bias
=
USE_BIAS
?
bia
s
.
data_ptr
<
packed_wscale_t
>
()
:
nullptr
,
}
,
.
scale
=
USE_SCALE
?
wcscales
.
data_ptr
<
packed_wscale_t
>
()
:
nullptr
,
nextArgs
,
}
,
{}
nextArgs
,
});
{}
});
});
});
});
});
};
};
// auto launch_bias = launch;
// auto launch_bias = launch;
auto
launch_lora
=
[
&
]
<
typename
NextEpilogue
,
typename
MidEpilogue
>
(
NextEpilogue
::
Arguments
nextArgs
,
MidEpilogue
::
Arguments
midArgs
)
{
auto
launch_lora
=
[
&
]
<
typename
NextEpilogue
,
typename
MidEpilogue
>
(
NextEpilogue
::
Arguments
nextArgs
,
MidEpilogue
::
Arguments
midArgs
)
{
assert
(
lora_up
.
valid
()
==
lora_act_in
.
valid
());
assert
(
lora_up
.
valid
()
==
lora_act_in
.
valid
());
assert
(
lora_down
.
valid
()
==
lora_act_out
.
valid
());
assert
(
lora_down
.
valid
()
==
lora_act_out
.
valid
());
const
int
rank_up
=
lora_up
.
valid
()
?
lora_up
.
shape
[
1
]
:
0
;
const
int
rank_up
=
lora_up
.
valid
()
?
lora_up
.
shape
[
1
]
:
0
;
const
int
rank_down
=
lora_down
.
valid
()
?
lora_down
.
shape
[
1
]
:
0
;
const
int
rank_down
=
lora_down
.
valid
()
?
lora_down
.
shape
[
1
]
:
0
;
if
(
rank_up
==
0
)
{
if
(
rank_up
==
0
)
{
assert
(
rank_down
==
0
);
assert
(
rank_down
==
0
);
return
launch_bias
.
template
operator
()
<
typename
GEMM
::
EpilogueCombination
<
MidEpilogue
,
NextEpilogue
>
>
({
midArgs
,
nextArgs
});
return
launch_bias
.
template
operator
()
<
typename
GEMM
::
EpilogueCombination
<
MidEpilogue
,
NextEpilogue
>
>
(
{
midArgs
,
nextArgs
});
}
}
assert
(
rank_up
%
16
==
0
);
assert
(
rank_up
%
16
==
0
);
assert
(
lora_up
.
shape
[
0
]
==
N
);
assert
(
lora_up
.
shape
[
0
]
==
N
);
...
@@ -207,7 +214,7 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
...
@@ -207,7 +214,7 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
assert
(
lora_act_in
.
shape
[
0
]
==
M
);
assert
(
lora_act_in
.
shape
[
0
]
==
M
);
assert
(
lora_act_in
.
shape
[
1
]
==
rank_up
);
assert
(
lora_act_in
.
shape
[
1
]
==
rank_up
);
using
LoraUp
=
Lora
;
using
LoraUp
=
Lora
;
using
scale_t
=
typename
LoraUp
::
scale_t
;
using
scale_t
=
typename
LoraUp
::
scale_t
;
scale_t
scales
;
scale_t
scales
;
...
@@ -218,19 +225,20 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
...
@@ -218,19 +225,20 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
}
}
if
(
rank_down
==
0
)
{
if
(
rank_down
==
0
)
{
using
Epilogue
=
typename
GEMM
::
EpilogueCombination
<
typename
LoraUp
::
EpilogueLoraUp
,
MidEpilogue
,
NextEpilogue
,
typename
GEMM
::
EpilogueNop
>
;
using
Epilogue
=
typename
GEMM
::
EpilogueCombination
<
typename
LoraUp
::
EpilogueLoraUp
,
return
launch_bias
.
template
operator
()
<
Epilogue
>({
MidEpilogue
,
typename
LoraUp
::
EpilogueLoraUp
::
Arguments
{
NextEpilogue
,
.
lora_act
=
lora_act_in
.
data_ptr
<
float
>
(),
typename
GEMM
::
EpilogueNop
>
;
.
lora_wgt_up
=
lora_up
.
data_ptr
<
packed_fpsum_t
>
(),
return
launch_bias
.
template
operator
()
<
Epilogue
>({
typename
LoraUp
::
EpilogueLoraUp
::
Arguments
{
.
rank
=
rank_up
,
.
lora_act
=
lora_act_in
.
data_ptr
<
float
>
(),
.
scales
=
scales
,
.
lora_wgt_up
=
lora_up
.
data_ptr
<
packed_fpsum_t
>
(),
.
alwaysfalse
=
false
,
.
rank
=
rank_up
,
},
.
scales
=
scales
,
midArgs
,
.
alwaysfalse
=
false
,
nextArgs
,
},
{}
midArgs
,
});
nextArgs
,
{}});
}
}
// assert(rank_down == rank_up);
// assert(rank_down == rank_up);
...
@@ -246,25 +254,27 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
...
@@ -246,25 +254,27 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
// dispatchVal(rank_down, std::integer_sequence<int, 16, 32, 48, 64, 80>(), [&]<int RANK_DOWN>() {
// dispatchVal(rank_down, std::integer_sequence<int, 16, 32, 48, 64, 80>(), [&]<int RANK_DOWN>() {
using
LoraDown
=
LoraUp
;
// GEMM::Lora<RANK_DOWN>;
using
LoraDown
=
LoraUp
;
// GEMM::Lora<RANK_DOWN>;
using
Epilogue
=
typename
GEMM
::
EpilogueCombination
<
typename
LoraUp
::
EpilogueLoraUp
,
MidEpilogue
,
typename
LoraDown
::
EpilogueLoraDown
,
NextEpilogue
,
typename
GEMM
::
EpilogueNop
>
;
using
Epilogue
=
typename
GEMM
::
EpilogueCombination
<
typename
LoraUp
::
EpilogueLoraUp
,
return
launch_bias
.
template
operator
()
<
Epilogue
>({
MidEpilogue
,
typename
LoraUp
::
EpilogueLoraUp
::
Arguments
{
typename
LoraDown
::
EpilogueLoraDown
,
.
lora_act
=
lora_act_in
.
data_ptr
<
float
>
(),
NextEpilogue
,
.
lora_wgt_up
=
lora_up
.
data_ptr
<
packed_fpsum_t
>
(),
typename
GEMM
::
EpilogueNop
>
;
.
rank
=
rank_up
,
return
launch_bias
.
template
operator
()
<
Epilogue
>({
typename
LoraUp
::
EpilogueLoraUp
::
Arguments
{
.
scales
=
scales
,
.
lora_act
=
lora_act_in
.
data_ptr
<
float
>
(),
.
alwaysfalse
=
false
,
.
lora_wgt_up
=
lora_up
.
data_ptr
<
packed_fpsum_t
>
(),
},
.
rank
=
rank_up
,
midArgs
,
.
scales
=
scales
,
typename
LoraDown
::
EpilogueLoraDown
::
Arguments
{
.
alwaysfalse
=
false
,
.
lora_wgt_down
=
lora_down
.
data_ptr
<
packed_fpsum_t
>
(),
},
.
lora_act
=
lora_act_out
.
data_ptr
<
float
>
(),
midArgs
,
.
rank
=
rank_down
,
typename
LoraDown
::
EpilogueLoraDown
::
Arguments
{
.
alwaysfalse
=
false
,
.
lora_wgt_down
=
lora_down
.
data_ptr
<
packed_fpsum_t
>
(),
},
.
lora_act
=
lora_act_out
.
data_ptr
<
float
>
(),
nextArgs
,
.
rank
=
rank_down
,
{}
.
alwaysfalse
=
false
,
});
},
nextArgs
,
{}});
// });
// });
};
};
...
@@ -276,29 +286,28 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
...
@@ -276,29 +286,28 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
static
constexpr
float
SHIFT_GELU
=
0.171875
f
;
static
constexpr
float
SHIFT_GELU
=
0.171875
f
;
constexpr
bool
USE_UNSIGNED
=
!
USE_FP4
;
constexpr
bool
USE_UNSIGNED
=
!
USE_FP4
;
using
EpilogueQuantize
=
typename
GEMM
::
EpilogueQuantize
<
false
,
USE_UNSIGNED
,
USE_FP4
>
;
using
EpilogueQuantize
=
typename
GEMM
::
EpilogueQuantize
<
false
,
USE_UNSIGNED
,
USE_FP4
>
;
auto
argsQuantize
=
typename
EpilogueQuantize
::
Arguments
{
auto
argsQuantize
=
.
qout
=
qout
.
data_ptr
<
packed_act_t
>
(),
typename
EpilogueQuantize
::
Arguments
{.
qout
=
qout
.
data_ptr
<
packed_act_t
>
(),
.
oscales
=
oscales
.
data_ptr
<
typename
EpilogueQuantize
::
oscales_t
>
(),
.
oscales
=
oscales
.
data_ptr
<
typename
EpilogueQuantize
::
oscales_t
>
(),
.
shift_value
=
USE_FP4
?
0.0
f
:
SHIFT_GELU
,
.
shift_value
=
USE_FP4
?
0.0
f
:
SHIFT_GELU
,
.
smooth_factor
=
smooth_factor
.
data_ptr
<
packed_wscale_t
>
()
.
smooth_factor
=
smooth_factor
.
data_ptr
<
packed_wscale_t
>
()};
};
// TODO: check if gelu is needed
// TODO: check if gelu is needed
if
(
out
.
valid
())
{
if
(
out
.
valid
())
{
launch_lora
.
template
operator
()
<
typename
GEMM
::
EpilogueCombination
<
typename
GEMM
::
EpilogueDefault
,
EpilogueQuantize
>,
typename
Epilogues
::
EpilogueGelu
>
({
launch_lora
.
template
typename
GEMM
::
EpilogueDefault
::
Arguments
{
operator
()
<
typename
GEMM
::
EpilogueCombination
<
typename
GEMM
::
EpilogueDefault
,
EpilogueQuantize
>,
.
out
=
out
.
data_ptr
<
half_t
>
(),
typename
Epilogues
::
EpilogueGelu
>
({
typename
GEMM
::
EpilogueDefault
::
Arguments
{
.
actualM
=
actualM
,
.
out
=
out
.
data_ptr
<
half_t
>
(),
.
actualN
=
actualN
,
.
actualM
=
actualM
,
},
.
actualN
=
actualN
,
argsQuantize
},
},
{});
argsQuantize
},
{});
}
else
{
}
else
{
launch_lora
.
template
operator
()
<
EpilogueQuantize
,
typename
Epilogues
::
EpilogueGelu
>(
argsQuantize
,
{});
launch_lora
.
template
operator
()
<
EpilogueQuantize
,
typename
Epilogues
::
EpilogueGelu
>(
argsQuantize
,
{});
}
}
}
else
if
(
out_linearattn
.
valid
())
{
}
else
if
(
out_linearattn
.
valid
())
{
assert
(
out_vk
.
valid
());
assert
(
out_vk
.
valid
());
...
@@ -311,7 +320,7 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
...
@@ -311,7 +320,7 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
assert
(
out_vk
.
shape
[
3
]
==
Epilogue
::
LITELA_HEAD_DIM
);
assert
(
out_vk
.
shape
[
3
]
==
Epilogue
::
LITELA_HEAD_DIM
);
assert
(
out_vk
.
shape
[
1
]
*
Epilogue
::
LITELA_HEAD_DIM
*
3
==
N
);
assert
(
out_vk
.
shape
[
1
]
*
Epilogue
::
LITELA_HEAD_DIM
*
3
==
N
);
int
batch_size
=
out_vk
.
shape
[
0
];
int
batch_size
=
out_vk
.
shape
[
0
];
int
num_heads
=
out_vk
.
shape
[
1
];
int
num_heads
=
out_vk
.
shape
[
1
];
assert
(
isTypeMatch
<
half_t
>
(
out_linearattn
.
dtype
()));
assert
(
isTypeMatch
<
half_t
>
(
out_linearattn
.
dtype
()));
assert
(
out_linearattn
.
ndims
()
==
3
);
assert
(
out_linearattn
.
ndims
()
==
3
);
...
@@ -326,12 +335,14 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
...
@@ -326,12 +335,14 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
out_vk
.
zero_
();
out_vk
.
zero_
();
launch_lora
.
template
operator
()
<
Epilogue
,
typename
GEMM
::
EpilogueNop
>(
typename
Epilogue
::
Arguments
{
launch_lora
.
template
operator
()
<
Epilogue
,
typename
GEMM
::
EpilogueNop
>(
.
out_q
=
out_linearattn
.
data_ptr
<
half_t
>
(),
typename
Epilogue
::
Arguments
{
.
out_vk
=
out_vk
.
data_ptr
<
float
>
(),
.
out_q
=
out_linearattn
.
data_ptr
<
half_t
>
(),
.
num_blocks_per_batch
=
num_blocks_per_batch
,
.
out_vk
=
out_vk
.
data_ptr
<
float
>
(),
.
actualM
=
M
,
.
num_blocks_per_batch
=
num_blocks_per_batch
,
},
{});
.
actualM
=
M
,
},
{});
}
else
if
(
rotary_emb
.
valid
())
{
}
else
if
(
rotary_emb
.
valid
())
{
assert
(
norm_q
.
valid
());
assert
(
norm_q
.
valid
());
...
@@ -342,8 +353,9 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
...
@@ -342,8 +353,9 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
assert
(
rotary_emb
.
shape
[
0
]
*
rotary_emb
.
shape
[
1
]
==
M
);
assert
(
rotary_emb
.
shape
[
0
]
*
rotary_emb
.
shape
[
1
]
==
M
);
assert
(
rotary_emb
.
shape
[
2
]
==
Epilogues
::
EpilogueRMSNormRope
::
HEAD_DIM
);
assert
(
rotary_emb
.
shape
[
2
]
==
Epilogues
::
EpilogueRMSNormRope
::
HEAD_DIM
);
// assert(rotary_emb.numel() == M * GEMM::EpilogueQKVProj::HEAD_DIM / 2 * GEMM::EpilogueQKVProj::ROTARY_EMB_NUM_ELEMENTS);
// assert(rotary_emb.numel() == M * GEMM::EpilogueQKVProj::HEAD_DIM / 2 *
// launch_lora.template operator()<typename GEMM::EpilogueQKVProj, typename GEMM::EpilogueNop>(typename GEMM::EpilogueQKVProj::Arguments{
// GEMM::EpilogueQKVProj::ROTARY_EMB_NUM_ELEMENTS); launch_lora.template operator()<typename
// GEMM::EpilogueQKVProj, typename GEMM::EpilogueNop>(typename GEMM::EpilogueQKVProj::Arguments{
// .out = out.data_ptr<half_t>(),
// .out = out.data_ptr<half_t>(),
// .actualM = actualM,
// .actualM = actualM,
// .actualN = actualN,
// .actualN = actualN,
...
@@ -355,42 +367,48 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
...
@@ -355,42 +367,48 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
// }, {});
// }, {});
using
EpilogueRope
=
typename
Epilogues
::
EpilogueRMSNormRope
;
using
EpilogueRope
=
typename
Epilogues
::
EpilogueRMSNormRope
;
auto
argsRope
=
typename
Epilogues
::
EpilogueRMSNormRope
::
Arguments
{
auto
argsRope
=
typename
Epilogues
::
EpilogueRMSNormRope
::
Arguments
{
.
rotary_emb
=
rotary_emb
.
data_ptr
<
typename
EpilogueRope
::
packed_rotemb_t
>
(),
.
rotary_emb
=
rotary_emb
.
data_ptr
<
typename
EpilogueRope
::
packed_rotemb_t
>
(),
.
rmsnorm_weight_q
=
norm_q
.
data_ptr
<
half_t
>
(),
.
rmsnorm_weight_q
=
norm_q
.
data_ptr
<
half_t
>
(),
.
rmsnorm_weight_k
=
norm_k
.
data_ptr
<
half_t
>
(),
.
rmsnorm_weight_k
=
norm_k
.
data_ptr
<
half_t
>
(),
.
epsilon
=
1e-6
,
.
epsilon
=
1e-6
,
};
};
if
(
out_q
.
valid
())
{
if
(
out_q
.
valid
())
{
launch_lora
.
template
operator
()
<
typename
GEMM
::
EpilogueCombination
<
EpilogueRope
,
typename
Epilogues
::
EpiloguePackQKV
>,
typename
GEMM
::
EpilogueNop
>
({
launch_lora
.
template
argsRope
,
operator
()
<
typename
GEMM
::
EpilogueCombination
<
EpilogueRope
,
typename
Epilogues
::
EpiloguePackQKV
>,
typename
Epilogues
::
EpiloguePackQKV
::
Arguments
{
typename
GEMM
::
EpilogueNop
>
(
.
out_q
=
out_q
.
data_ptr
<
typename
Epilogues
::
EpiloguePackQKV
::
packed_qkv_t
>
(),
{
argsRope
,
.
out_k
=
out_k
.
data_ptr
<
typename
Epilogues
::
EpiloguePackQKV
::
packed_qkv_t
>
(),
typename
Epilogues
::
EpiloguePackQKV
::
Arguments
{
.
out_v
=
out_v
.
data_ptr
<
typename
Epilogues
::
EpiloguePackQKV
::
packed_qkv_t
>
(),
.
out_q
=
out_q
.
data_ptr
<
typename
Epilogues
::
EpiloguePackQKV
::
packed_qkv_t
>
(),
.
actualM
=
attn_tokens
,
.
out_k
=
out_k
.
data_ptr
<
typename
Epilogues
::
EpiloguePackQKV
::
packed_qkv_t
>
(),
.
strideHead_q
=
int
(
out_q
.
stride
(
1
)
*
out_q
.
scalar_size
()
/
sizeof
(
typename
Epilogues
::
EpiloguePackQKV
::
packed_qkv_t
)),
.
out_v
=
out_v
.
data_ptr
<
typename
Epilogues
::
EpiloguePackQKV
::
packed_qkv_t
>
(),
.
strideHead_k
=
int
(
out_k
.
stride
(
1
)
*
out_k
.
scalar_size
()
/
sizeof
(
typename
Epilogues
::
EpiloguePackQKV
::
packed_qkv_t
)),
.
actualM
=
attn_tokens
,
.
strideHead_v
=
int
(
out_v
.
stride
(
1
)
*
out_v
.
scalar_size
()
/
sizeof
(
typename
Epilogues
::
EpiloguePackQKV
::
packed_qkv_t
)),
.
strideHead_q
=
int
(
out_q
.
stride
(
1
)
*
out_q
.
scalar_size
()
/
}
sizeof
(
typename
Epilogues
::
EpiloguePackQKV
::
packed_qkv_t
)),
},
{});
.
strideHead_k
=
int
(
out_k
.
stride
(
1
)
*
out_k
.
scalar_size
()
/
sizeof
(
typename
Epilogues
::
EpiloguePackQKV
::
packed_qkv_t
)),
.
strideHead_v
=
int
(
out_v
.
stride
(
1
)
*
out_v
.
scalar_size
()
/
sizeof
(
typename
Epilogues
::
EpiloguePackQKV
::
packed_qkv_t
)),
}},
{});
}
else
{
}
else
{
launch_lora
.
template
operator
()
<
typename
GEMM
::
EpilogueCombination
<
EpilogueRope
,
typename
GEMM
::
EpilogueDefault
>,
typename
GEMM
::
EpilogueNop
>
({
launch_lora
argsRope
,
.
template
operator
()
<
typename
GEMM
::
EpilogueCombination
<
EpilogueRope
,
typename
GEMM
::
EpilogueDefault
>,
typename
GEMM
::
EpilogueDefault
::
Arguments
{
typename
GEMM
::
EpilogueNop
>
({
argsRope
,
.
out
=
out
.
data_ptr
<
half_t
>
(),
typename
GEMM
::
EpilogueDefault
::
Arguments
{
.
actualM
=
actualM
,
.
out
=
out
.
data_ptr
<
half_t
>
(),
.
actualN
=
actualN
,
.
actualM
=
actualM
,
}
.
actualN
=
actualN
,
},
{});
}},
{});
}
}
}
else
if
(
out
.
valid
())
{
}
else
if
(
out
.
valid
())
{
using
Epilogue
=
typename
GEMM
::
EpilogueDefault
;
using
Epilogue
=
typename
GEMM
::
EpilogueDefault
;
typename
Epilogue
::
Arguments
args
{
typename
Epilogue
::
Arguments
args
{
.
out
=
out
.
data_ptr
<
half_t
>
(),
.
out
=
out
.
data_ptr
<
half_t
>
(),
.
actualM
=
actualM
,
.
actualM
=
actualM
,
.
actualN
=
actualN
,
.
actualN
=
actualN
,
};
};
...
@@ -410,7 +428,7 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::linearattn_vk_mul_q(Tensor q, Tensor vk)
...
@@ -410,7 +428,7 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::linearattn_vk_mul_q(Tensor q, Tensor vk)
using
Epilogue
=
typename
Epilogues
::
EpilogueLiteLA
;
using
Epilogue
=
typename
Epilogues
::
EpilogueLiteLA
;
int
batch_size
=
vk
.
shape
[
0
];
int
batch_size
=
vk
.
shape
[
0
];
int
num_heads
=
vk
.
shape
[
1
];
int
num_heads
=
vk
.
shape
[
1
];
int
num_tokens
=
q
.
shape
[
1
];
int
num_tokens
=
q
.
shape
[
1
];
assert
(
isTypeMatch
<
half_t
>
(
q
.
scalar_type
()));
assert
(
isTypeMatch
<
half_t
>
(
q
.
scalar_type
()));
...
@@ -423,17 +441,21 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::linearattn_vk_mul_q(Tensor q, Tensor vk)
...
@@ -423,17 +441,21 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::linearattn_vk_mul_q(Tensor q, Tensor vk)
BLOCK_SIZE
=
128
;
BLOCK_SIZE
=
128
;
}
}
invoke_kernel
<
typename
Epilogue
::
vk_mul_q_kernel
><<<
dim3
(
ceilDiv
(
num_tokens
,
BLOCK_SIZE
),
num_heads
,
batch_size
),
BLOCK_SIZE
,
0
,
getCurrentCUDAStream
()
>>>
(
invoke_kernel
<
typename
Epilogue
::
vk_mul_q_kernel
>
q
.
data_ptr
<
half_t
>
(),
<<<
dim3
(
ceilDiv
(
num_tokens
,
BLOCK_SIZE
),
num_heads
,
batch_size
),
BLOCK_SIZE
,
0
,
getCurrentCUDAStream
()
>>>
(
vk
.
data_ptr
<
float
>
(),
q
.
data_ptr
<
half_t
>
(),
vk
.
data_ptr
<
float
>
(),
1e-6
f
,
num_tokens
);
1e-6
f
,
num_tokens
);
checkCUDA
(
cudaGetLastError
());
checkCUDA
(
cudaGetLastError
());
}
}
template
<
typename
Config
,
bool
USE_FP4
>
template
<
typename
Config
,
bool
USE_FP4
>
void
GEMM_W4A4_Launch
<
Config
,
USE_FP4
>::
quantize_w4a4_act_fuse_lora
(
Tensor
input
,
Tensor
output
,
Tensor
oscales
,
Tensor
lora_down
,
Tensor
lora_act_out
,
Tensor
smooth
,
bool
fuse_glu
,
bool
fp4
)
{
void
GEMM_W4A4_Launch
<
Config
,
USE_FP4
>::
quantize_w4a4_act_fuse_lora
(
Tensor
input
,
Tensor
output
,
Tensor
oscales
,
Tensor
lora_down
,
Tensor
lora_act_out
,
Tensor
smooth
,
bool
fuse_glu
,
bool
fp4
)
{
const
int
actualM
=
input
.
numel
()
/
input
.
shape
[
-
1
];
const
int
actualM
=
input
.
numel
()
/
input
.
shape
[
-
1
];
const
int
actualN
=
input
.
shape
[
-
1
];
const
int
actualN
=
input
.
shape
[
-
1
];
...
@@ -475,24 +497,24 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act_fuse_lora(Tensor input
...
@@ -475,24 +497,24 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act_fuse_lora(Tensor input
checkCUDA
(
cudaFuncSetAttribute
(
func
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
kernel
::
SHMEM_SIZE
));
checkCUDA
(
cudaFuncSetAttribute
(
func
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
kernel
::
SHMEM_SIZE
));
// log(std::format("quantize_w4a4_act_fuse_lora M={} N={} input={} output={} (size={} numel={})", M, N, input.data_ptr(), output.data_ptr(), output.buffer->getSize(), output.numel()));
// log(std::format("quantize_w4a4_act_fuse_lora M={} N={} input={} output={} (size={} numel={})", M, N,
// input.data_ptr(), output.data_ptr(), output.buffer->getSize(), output.numel()));
func
<<<
grid
,
GEMM
::
WARP_SIZE
*
GEMM
::
NUM_WARPS
,
kernel
::
SHMEM_SIZE
,
getCurrentCUDAStream
()
>>>
(
func
<<<
grid
,
GEMM
::
WARP_SIZE
*
GEMM
::
NUM_WARPS
,
kernel
::
SHMEM_SIZE
,
getCurrentCUDAStream
()
>>>
(
typename
kernel
::
Arguments
{
typename
kernel
::
Arguments
{
.
input
=
input
.
data_ptr
<
half_t
>
(),
.
input
=
input
.
data_ptr
<
half_t
>
(),
.
smooth_factor
=
smooth
.
valid
()
?
smooth
.
data_ptr
<
packed_wscale_t
>
()
:
nullptr
,
.
smooth_factor
=
smooth
.
valid
()
?
smooth
.
data_ptr
<
packed_wscale_t
>
()
:
nullptr
,
.
output
=
output
.
data_ptr
<
packed_act_t
>
(),
.
output
=
output
.
data_ptr
<
packed_act_t
>
(),
.
oscales
=
oscales
.
data_ptr
<
typename
kernel
::
oscales_t
>
(),
.
oscales
=
oscales
.
data_ptr
<
typename
kernel
::
oscales_t
>
(),
.
lora_wgt_down
=
lora_down
.
data_ptr
<
packed_fpsum_t
>
(),
.
lora_wgt_down
=
lora_down
.
data_ptr
<
packed_fpsum_t
>
(),
.
lora_act
=
lora_act_out
.
data_ptr
<
float
>
(),
.
lora_act
=
lora_act_out
.
data_ptr
<
float
>
(),
.
lora_rank
=
rank
,
.
lora_rank
=
rank
,
.
M
=
M
,
.
M
=
M
,
.
N
=
N
,
.
N
=
N
,
.
actualM
=
actualM
,
.
actualM
=
actualM
,
.
actualN
=
actualN
,
.
actualN
=
actualN
,
.
alwaysfalse
=
false
,
.
alwaysfalse
=
false
,
}
});
);
checkCUDA
(
cudaGetLastError
());
checkCUDA
(
cudaGetLastError
());
});
});
// });
// });
...
@@ -501,7 +523,7 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act_fuse_lora(Tensor input
...
@@ -501,7 +523,7 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act_fuse_lora(Tensor input
template
<
typename
Config
,
bool
USE_FP4
>
template
<
typename
Config
,
bool
USE_FP4
>
void
GEMM_W4A4_Launch
<
Config
,
USE_FP4
>::
quantize_w4a4_act
(
Tensor
input
,
Tensor
output
,
Tensor
oscales
)
{
void
GEMM_W4A4_Launch
<
Config
,
USE_FP4
>::
quantize_w4a4_act
(
Tensor
input
,
Tensor
output
,
Tensor
oscales
)
{
if
constexpr
(
USE_FP4
)
{
if
constexpr
(
USE_FP4
)
{
assert
(
false
);
// not implemented
assert
(
false
);
// not implemented
return
;
return
;
}
}
...
@@ -518,11 +540,7 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act(Tensor input, Tensor o
...
@@ -518,11 +540,7 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act(Tensor input, Tensor o
dim3
grid
(
M
/
GEMM
::
WARP_M
,
K
/
GEMM
::
WARP_K
);
dim3
grid
(
M
/
GEMM
::
WARP_M
,
K
/
GEMM
::
WARP_K
);
invoke_kernel
<
typename
GEMM
::
quantize_w4a4_act_kernel
><<<
grid
,
GEMM
::
WARP_SIZE
,
0
,
getCurrentCUDAStream
()
>>>
(
invoke_kernel
<
typename
GEMM
::
quantize_w4a4_act_kernel
><<<
grid
,
GEMM
::
WARP_SIZE
,
0
,
getCurrentCUDAStream
()
>>>
(
input
.
data_ptr
<
half_t
>
(),
input
.
data_ptr
<
half_t
>
(),
output
.
data_ptr
<
packed_act_t
>
(),
oscales
.
data_ptr
<
packed_ascale_t
>
(),
K
);
output
.
data_ptr
<
packed_act_t
>
(),
oscales
.
data_ptr
<
packed_ascale_t
>
(),
K
);
checkCUDA
(
cudaGetLastError
());
checkCUDA
(
cudaGetLastError
());
}
}
...
@@ -540,19 +558,15 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_wgt(Tensor input, Tensor o
...
@@ -540,19 +558,15 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_wgt(Tensor input, Tensor o
assert
(
output
.
ndims
()
==
2
);
assert
(
output
.
ndims
()
==
2
);
assert
(
output
.
shape
[
0
]
==
N
);
assert
(
output
.
shape
[
0
]
==
N
);
assert
(
output
.
shape
[
1
]
==
K
/
2
);
assert
(
output
.
shape
[
1
]
==
K
/
2
);
assert
(
isTypeMatch
<
half_t
>
(
oscales
.
dtype
()));
assert
(
isTypeMatch
<
half_t
>
(
oscales
.
dtype
()));
// assert(oscales.dtype() == Tensor::FP16);
// assert(oscales.dtype() == Tensor::FP16);
assert
(
oscales
.
numel
()
==
N
*
K
/
GEMM
::
WARP_K
);
assert
(
oscales
.
numel
()
==
N
*
K
/
GEMM
::
WARP_K
);
dim3
grid
(
N
/
GEMM
::
WARP_N
,
K
/
GEMM
::
WARP_K
);
dim3
grid
(
N
/
GEMM
::
WARP_N
,
K
/
GEMM
::
WARP_K
);
invoke_kernel
<
typename
GEMM
::
quantize_w4a4_wgt_kernel
><<<
grid
,
GEMM
::
WARP_SIZE
,
0
,
getCurrentCUDAStream
()
>>>
(
invoke_kernel
<
typename
GEMM
::
quantize_w4a4_wgt_kernel
><<<
grid
,
GEMM
::
WARP_SIZE
,
0
,
getCurrentCUDAStream
()
>>>
(
input
.
data_ptr
<
half_t
>
(),
input
.
data_ptr
<
half_t
>
(),
output
.
data_ptr
<
packed_wgt_t
>
(),
oscales
.
data_ptr
<
packed_wscale_t
>
(),
K
);
output
.
data_ptr
<
packed_wgt_t
>
(),
oscales
.
data_ptr
<
packed_wscale_t
>
(),
K
);
checkCUDA
(
cudaGetLastError
());
checkCUDA
(
cudaGetLastError
());
}
}
};
// namespace nunchaku::kernels
};
// namespace nunchaku::kernels
\ No newline at end of file
src/kernels/zgemm/gemm_w4a4_test.cu
View file @
37a27712
...
@@ -11,7 +11,7 @@ void test_rmsnorm_rope(Tensor input, Tensor output, Tensor norm_q, Tensor norm_k
...
@@ -11,7 +11,7 @@ void test_rmsnorm_rope(Tensor input, Tensor output, Tensor norm_q, Tensor norm_k
assert
(
input
.
shape
.
dataExtent
==
output
.
shape
.
dataExtent
);
assert
(
input
.
shape
.
dataExtent
==
output
.
shape
.
dataExtent
);
assert
(
input
.
scalar_type
()
==
Tensor
::
FP16
);
assert
(
input
.
scalar_type
()
==
Tensor
::
FP16
);
using
GEMM
=
Epilogues
<
GEMMConfig_W4A4_FP16
>
;
using
GEMM
=
Epilogues
<
GEMMConfig_W4A4_FP16
>
;
using
Epilogue
=
GEMM
::
EpilogueRMSNormRope
;
using
Epilogue
=
GEMM
::
EpilogueRMSNormRope
;
assert
(
M
%
GEMM
::
BLOCK_M
==
0
);
assert
(
M
%
GEMM
::
BLOCK_M
==
0
);
...
@@ -26,21 +26,18 @@ void test_rmsnorm_rope(Tensor input, Tensor output, Tensor norm_q, Tensor norm_k
...
@@ -26,21 +26,18 @@ void test_rmsnorm_rope(Tensor input, Tensor output, Tensor norm_q, Tensor norm_k
dim3
grid
(
M
/
GEMM
::
BLOCK_M
,
N
/
GEMM
::
BLOCK_N
);
dim3
grid
(
M
/
GEMM
::
BLOCK_M
,
N
/
GEMM
::
BLOCK_N
);
func
<<<
grid
,
GEMM
::
WARP_SIZE
*
GEMM
::
NUM_WARPS
,
kernel
::
SHMEM_SIZE
,
getCurrentCUDAStream
()
>>>
(
func
<<<
grid
,
GEMM
::
WARP_SIZE
*
GEMM
::
NUM_WARPS
,
kernel
::
SHMEM_SIZE
,
getCurrentCUDAStream
()
>>>
(
typename
kernel
::
Arguments
{
typename
kernel
::
Arguments
{.
input
=
input
.
data_ptr
<
GEMM
::
half_t
>
(),
.
input
=
input
.
data_ptr
<
GEMM
::
half_t
>
(),
.
output
=
output
.
data_ptr
<
GEMM
::
half_t
>
(),
.
output
=
output
.
data_ptr
<
GEMM
::
half_t
>
(),
.
M
=
M
,
.
M
=
M
,
.
N
=
N
,
.
N
=
N
,
.
actualM
=
M
,
.
actualM
=
M
,
.
actualN
=
N
,
.
actualN
=
N
,
.
argsEpilogue
=
typename
Epilogue
::
Arguments
{
.
argsEpilogue
=
typename
Epilogue
::
Arguments
{
.
rotary_emb
=
rotary_emb
.
data_ptr
<
typename
Epilogue
::
packed_rotemb_t
>
(),
.
rotary_emb
=
rotary_emb
.
data_ptr
<
typename
Epilogue
::
packed_rotemb_t
>
(),
.
rmsnorm_weight_q
=
norm_q
.
data_ptr
<
GEMM
::
half_t
>
(),
.
rmsnorm_weight_q
=
norm_q
.
data_ptr
<
GEMM
::
half_t
>
(),
.
rmsnorm_weight_k
=
norm_k
.
data_ptr
<
GEMM
::
half_t
>
(),
.
rmsnorm_weight_k
=
norm_k
.
data_ptr
<
GEMM
::
half_t
>
(),
.
epsilon
=
1e-6
,
.
epsilon
=
1e-6
,
}});
}
}
);
checkCUDA
(
cudaGetLastError
());
checkCUDA
(
cudaGetLastError
());
}
}
...
@@ -52,7 +49,7 @@ void test_pack_qkv(Tensor input, Tensor out_q, Tensor out_k, Tensor out_v, int n
...
@@ -52,7 +49,7 @@ void test_pack_qkv(Tensor input, Tensor out_q, Tensor out_k, Tensor out_v, int n
Tensor
output
=
Tensor
::
empty_like
(
input
);
Tensor
output
=
Tensor
::
empty_like
(
input
);
using
GEMM
=
Epilogues
<
GEMMConfig_W4A4_FP16
>
;
using
GEMM
=
Epilogues
<
GEMMConfig_W4A4_FP16
>
;
using
Epilogue
=
GEMM
::
EpiloguePackQKV
;
using
Epilogue
=
GEMM
::
EpiloguePackQKV
;
assert
(
M
%
GEMM
::
BLOCK_M
==
0
);
assert
(
M
%
GEMM
::
BLOCK_M
==
0
);
...
@@ -68,24 +65,25 @@ void test_pack_qkv(Tensor input, Tensor out_q, Tensor out_k, Tensor out_v, int n
...
@@ -68,24 +65,25 @@ void test_pack_qkv(Tensor input, Tensor out_q, Tensor out_k, Tensor out_v, int n
func
<<<
grid
,
GEMM
::
WARP_SIZE
*
GEMM
::
NUM_WARPS
,
kernel
::
SHMEM_SIZE
,
getCurrentCUDAStream
()
>>>
(
func
<<<
grid
,
GEMM
::
WARP_SIZE
*
GEMM
::
NUM_WARPS
,
kernel
::
SHMEM_SIZE
,
getCurrentCUDAStream
()
>>>
(
typename
kernel
::
Arguments
{
typename
kernel
::
Arguments
{
.
input
=
input
.
data_ptr
<
GEMM
::
half_t
>
(),
.
input
=
input
.
data_ptr
<
GEMM
::
half_t
>
(),
.
output
=
output
.
data_ptr
<
GEMM
::
half_t
>
(),
.
output
=
output
.
data_ptr
<
GEMM
::
half_t
>
(),
.
M
=
M
,
.
M
=
M
,
.
N
=
N
,
.
N
=
N
,
.
actualM
=
M
,
.
actualM
=
M
,
.
actualN
=
N
,
.
actualN
=
N
,
.
argsEpilogue
=
typename
Epilogue
::
Arguments
{
.
argsEpilogue
=
typename
Epilogue
::
Arguments
{
.
out_q
=
out_q
.
data_ptr
<
typename
Epilogue
::
packed_qkv_t
>
(),
.
out_q
=
out_q
.
data_ptr
<
typename
Epilogue
::
packed_qkv_t
>
(),
.
out_k
=
out_k
.
data_ptr
<
typename
Epilogue
::
packed_qkv_t
>
(),
.
out_k
=
out_k
.
data_ptr
<
typename
Epilogue
::
packed_qkv_t
>
(),
.
out_v
=
out_v
.
data_ptr
<
typename
Epilogue
::
packed_qkv_t
>
(),
.
out_v
=
out_v
.
data_ptr
<
typename
Epilogue
::
packed_qkv_t
>
(),
.
actualM
=
numTokens
,
.
actualM
=
numTokens
,
.
strideHead_q
=
int
(
out_q
.
stride
(
1
)
*
out_q
.
scalar_size
()
/
sizeof
(
GEMM
::
EpiloguePackQKV
::
packed_qkv_t
)),
.
strideHead_q
=
.
strideHead_k
=
int
(
out_k
.
stride
(
1
)
*
out_k
.
scalar_size
()
/
sizeof
(
GEMM
::
EpiloguePackQKV
::
packed_qkv_t
)),
int
(
out_q
.
stride
(
1
)
*
out_q
.
scalar_size
()
/
sizeof
(
GEMM
::
EpiloguePackQKV
::
packed_qkv_t
)),
.
strideHead_v
=
int
(
out_v
.
stride
(
1
)
*
out_v
.
scalar_size
()
/
sizeof
(
GEMM
::
EpiloguePackQKV
::
packed_qkv_t
)),
.
strideHead_k
=
}
int
(
out_k
.
stride
(
1
)
*
out_k
.
scalar_size
()
/
sizeof
(
GEMM
::
EpiloguePackQKV
::
packed_qkv_t
)),
}
.
strideHead_v
=
);
int
(
out_v
.
stride
(
1
)
*
out_v
.
scalar_size
()
/
sizeof
(
GEMM
::
EpiloguePackQKV
::
packed_qkv_t
)),
}});
checkCUDA
(
cudaGetLastError
());
checkCUDA
(
cudaGetLastError
());
}
}
};
// namespace nunchaku::kernels
};
// namespace nunchaku::kernels
\ No newline at end of file
src/kernels/zgemm/gemm_w8a8.cu
View file @
37a27712
...
@@ -17,24 +17,22 @@ void quantize_w8a8_act(Tensor input, Tensor output, Tensor oscales, bool fuse_gl
...
@@ -17,24 +17,22 @@ void quantize_w8a8_act(Tensor input, Tensor output, Tensor oscales, bool fuse_gl
assert
(
oscales
.
numel
()
==
M
*
1
);
assert
(
oscales
.
numel
()
==
M
*
1
);
auto
launch
=
[
&
]
<
bool
FUSE_GLU
>
()
{
auto
launch
=
[
&
]
<
bool
FUSE_GLU
>
()
{
using
kernel
=
GEMM
::
quantize_w8a8_act_kernel
<
FUSE_GLU
>
;
using
kernel
=
GEMM
::
quantize_w8a8_act_kernel
<
FUSE_GLU
>
;
assert
(
kernel
::
check
(
M
,
K
));
assert
(
kernel
::
check
(
M
,
K
));
dim3
grid
=
kernel
::
gridSize
(
M
,
K
);
dim3
grid
=
kernel
::
gridSize
(
M
,
K
);
dim3
block
=
kernel
::
blockSize
(
M
,
K
);
dim3
block
=
kernel
::
blockSize
(
M
,
K
);
auto
func
=
invoke_kernel
<
kernel
,
const
GEMM
::
half_t
*
,
GEMM
::
packed_act_t
*
,
GEMM
::
packed_ascale_t
*
,
int
,
bool
>
;
auto
func
=
invoke_kernel
<
kernel
,
const
GEMM
::
half_t
*
,
GEMM
::
packed_act_t
*
,
GEMM
::
packed_ascale_t
*
,
int
,
bool
>
;
checkCUDA
(
cudaFuncSetAttribute
(
func
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
92160
));
checkCUDA
(
cudaFuncSetAttribute
(
func
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
92160
));
func
<<<
grid
,
block
,
kernel
::
smemSize
(
M
,
K
)
>>>
(
func
<<<
grid
,
block
,
kernel
::
smemSize
(
M
,
K
)
>>>
(
input
.
data_ptr
<
GEMM
::
half_t
>
(),
input
.
data_ptr
<
GEMM
::
half_t
>
(),
output
.
data_ptr
<
GEMM
::
packed_act_t
>
(),
output
.
data_ptr
<
GEMM
::
packed_act_t
>
(),
oscales
.
data_ptr
<
GEMM
::
packed_ascale_t
>
(),
oscales
.
data_ptr
<
GEMM
::
packed_ascale_t
>
(),
K
,
K
,
false
);
false
);
checkCUDA
(
cudaGetLastError
());
checkCUDA
(
cudaGetLastError
());
};
};
...
@@ -45,14 +43,12 @@ void quantize_w8a8_act(Tensor input, Tensor output, Tensor oscales, bool fuse_gl
...
@@ -45,14 +43,12 @@ void quantize_w8a8_act(Tensor input, Tensor output, Tensor oscales, bool fuse_gl
}
}
}
}
void
gemm_w8a8
(
Tensor
act
,
// [M, K]
void
gemm_w8a8
(
Tensor
act
,
// [M, K]
Tensor
wgt
,
// [N, K]
Tensor
wgt
,
// [N, K]
Tensor
out
,
// [M, N]
Tensor
out
,
// [M, N]
Tensor
ascales
,
// [1, M]
Tensor
ascales
,
// [1, M]
Tensor
wscales
,
// [1, N]
Tensor
wscales
,
// [1, N]
Tensor
bias
Tensor
bias
)
{
)
{
using
GEMM
=
GEMM_W8A8
;
using
GEMM
=
GEMM_W8A8
;
int
M
=
act
.
numel
()
/
act
.
shape
[
-
1
];
int
M
=
act
.
numel
()
/
act
.
shape
[
-
1
];
...
@@ -78,16 +74,18 @@ void gemm_w8a8(Tensor act, // [M, K]
...
@@ -78,16 +74,18 @@ void gemm_w8a8(Tensor act, // [M, K]
std
::
swap
(
grid
.
x
,
grid
.
y
);
std
::
swap
(
grid
.
x
,
grid
.
y
);
}
}
invoke_kernel
<
GEMM
::
gemm_w8a8_kernel
<
Epilogue
>><<<
grid
,
GEMM
::
WARP_SIZE
*
GEMM
::
NUM_WARPS
>>>
(
invoke_kernel
<
GEMM
::
gemm_w8a8_kernel
<
Epilogue
>>
act
.
data_ptr
<
GEMM
::
packed_act_t
>
(),
<<<
grid
,
GEMM
::
WARP_SIZE
*
GEMM
::
NUM_WARPS
>>>
(
act
.
data_ptr
<
GEMM
::
packed_act_t
>
(),
wgt
.
data_ptr
<
GEMM
::
packed_wgt_t
>
(),
wgt
.
data_ptr
<
GEMM
::
packed_wgt_t
>
(),
ascales
.
data_ptr
<
GEMM
::
packed_ascale_t
>
(),
ascales
.
data_ptr
<
GEMM
::
packed_ascale_t
>
(),
wscales
.
data_ptr
<
GEMM
::
packed_wscale_t
>
(),
wscales
.
data_ptr
<
GEMM
::
packed_wscale_t
>
(),
// out.valid() ? out.data_ptr<GEMM::half_t>() : nullptr,
// out.valid() ? out.data_ptr<GEMM::half_t>() : nullptr,
M
,
N
,
K
,
args
,
M
,
swapBlockMN
,
N
,
false
K
,
);
args
,
swapBlockMN
,
false
);
checkCUDA
(
cudaGetLastError
());
checkCUDA
(
cudaGetLastError
());
};
};
...
@@ -98,20 +96,19 @@ void gemm_w8a8(Tensor act, // [M, K]
...
@@ -98,20 +96,19 @@ void gemm_w8a8(Tensor act, // [M, K]
assert
(
bias
.
numel
()
==
N
);
assert
(
bias
.
numel
()
==
N
);
// append EpilgoueNop to workaround mismatched memory layout of std::tuple between device and host code on Windows
// append EpilgoueNop to workaround mismatched memory layout of std::tuple between device and host code on
// Windows
// ** sizeof(std::tuple<std::tuple<int>>) == 8 on device **
// ** sizeof(std::tuple<std::tuple<int>>) == 8 on device **
using
Epilogue
=
GEMM
::
EpilogueCombination
<
GEMM
::
EpilogueBias
<
true
,
false
>
,
NextEpilogue
,
GEMM
::
EpilogueNop
>
;
using
Epilogue
=
GEMM
::
EpilogueCombination
<
GEMM
::
EpilogueBias
<
true
,
false
>
,
NextEpilogue
,
GEMM
::
EpilogueNop
>
;
return
launch
.
template
operator
()
<
Epilogue
>({
return
launch
.
template
operator
()
<
Epilogue
>({
GEMM
::
EpilogueBias
<
true
,
false
>::
Arguments
{
GEMM
::
EpilogueBias
<
true
,
false
>::
Arguments
{
.
bias
=
bias
.
data_ptr
<
GEMM
::
packed_wscale_t
>
(),
.
bias
=
bias
.
data_ptr
<
GEMM
::
packed_wscale_t
>
(),
},
},
nextArgs
,
nextArgs
,
{}});
{}
});
};
};
launch_bias
.
template
operator
()
<
GEMM
::
EpilogueDefault
>(
GEMM
::
EpilogueDefault
::
Arguments
{
launch_bias
.
template
operator
()
<
GEMM
::
EpilogueDefault
>(
GEMM
::
EpilogueDefault
::
Arguments
{
.
out
=
out
.
data_ptr
<
GEMM
::
half_t
>
(),
.
out
=
out
.
data_ptr
<
GEMM
::
half_t
>
(),
.
actualM
=
actualM
,
.
actualM
=
actualM
,
.
actualN
=
actualN
,
.
actualN
=
actualN
,
});
});
...
@@ -152,9 +149,9 @@ void gemm_w8a8_fuse_litela(
...
@@ -152,9 +149,9 @@ void gemm_w8a8_fuse_litela(
checkCUDA(cudaMemsetAsync(out_vk.data_ptr(), 0, out_vk.buffer->getSize()));
checkCUDA(cudaMemsetAsync(out_vk.data_ptr(), 0, out_vk.buffer->getSize()));
auto func = invoke_kernel<GEMM::gemm_w8a8_kernel<Epilogue>,
auto func = invoke_kernel<GEMM::gemm_w8a8_kernel<Epilogue>,
const GEMM::packed_act_t *,
const GEMM::packed_act_t *,
const GEMM::packed_wgt_t *,
const GEMM::packed_wgt_t *,
const GEMM::packed_ascale_t *,
const GEMM::packed_ascale_t *,
const GEMM::packed_wscale_t *,
const GEMM::packed_wscale_t *,
// GEMM::half_t *,
// GEMM::half_t *,
...
@@ -178,7 +175,7 @@ void gemm_w8a8_fuse_litela(
...
@@ -178,7 +175,7 @@ void gemm_w8a8_fuse_litela(
ascales.data_ptr<GEMM::packed_ascale_t>(),
ascales.data_ptr<GEMM::packed_ascale_t>(),
wscales.data_ptr<GEMM::packed_wscale_t>(),
wscales.data_ptr<GEMM::packed_wscale_t>(),
// nullptr,
// nullptr,
M, N, K, epilogueArgs,
M, N, K, epilogueArgs,
swapBlockMN,
swapBlockMN,
false
false
);
);
...
@@ -193,4 +190,4 @@ void gemm_w8a8_fuse_litela(
...
@@ -193,4 +190,4 @@ void gemm_w8a8_fuse_litela(
}
}
#endif
#endif
};
// namespace nunchaku::kernels
};
// namespace nunchaku::kernels
\ No newline at end of file
src/kernels/zgemm/gemm_w8a8.cuh
View file @
37a27712
...
@@ -8,48 +8,52 @@ class GEMM_W8A8 : public GEMMBase<GEMMConfig_W8A8> {
...
@@ -8,48 +8,52 @@ class GEMM_W8A8 : public GEMMBase<GEMMConfig_W8A8> {
public:
public:
using
psum_warp
=
std
::
array
<
packed_psum_t
,
WARP_M_TILES
*
WARP_N_TILES
>
;
using
psum_warp
=
std
::
array
<
packed_psum_t
,
WARP_M_TILES
*
WARP_N_TILES
>
;
__device__
__forceinline__
__device__
__forceinline__
static
packed_psum_t
mma
(
packed_act_t
act
,
packed_wgt_t
wgt
,
packed_psum_t
psum
)
{
static
packed_psum_t
mma
(
packed_act_t
act
,
packed_wgt_t
wgt
,
packed_psum_t
psum
)
{
// packed_psum_t psum;
// packed_psum_t psum;
asm
volatile
(
asm
volatile
(
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 "
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 "
"{%0, %1, %2, %3},"
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%8, %9},"
"{%10, %11, %12, %13};
\n
"
"{%10, %11, %12, %13};
\n
"
:
"=r"
(
psum
.
data
[
0
]),
"=r"
(
psum
.
data
[
1
]),
"=r"
(
psum
.
data
[
2
]),
"=r"
(
psum
.
data
[
3
])
:
:
"r"
(
act
.
x
),
"=r"
(
psum
.
data
[
0
]),
"=r"
(
psum
.
data
[
1
]),
"=r"
(
psum
.
data
[
2
]),
"=r"
(
psum
.
data
[
3
])
"r"
(
act
.
y
),
:
"r"
(
act
.
z
),
"r"
(
act
.
x
),
"r"
(
act
.
y
),
"r"
(
act
.
z
),
"r"
(
act
.
w
),
"r"
(
act
.
w
),
"r"
(
wgt
.
x
),
"r"
(
wgt
.
y
),
"r"
(
wgt
.
x
),
// "r"(0), "r"(0), "r"(0), "r"(0)
"r"
(
wgt
.
y
),
"r"
(
psum
.
data
[
0
]),
"r"
(
psum
.
data
[
1
]),
"r"
(
psum
.
data
[
2
]),
"r"
(
psum
.
data
[
3
])
// "r"(0), "r"(0), "r"(0), "r"(0)
);
"r"
(
psum
.
data
[
0
]),
asm
volatile
(
"r"
(
psum
.
data
[
1
]),
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 "
"r"
(
psum
.
data
[
2
]),
"{%0, %1, %2, %3},"
"r"
(
psum
.
data
[
3
]));
"{%4, %5, %6, %7},"
asm
volatile
(
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 "
"{%8, %9},"
"{%0, %1, %2, %3},"
"{%10, %11, %12, %13};
\n
"
"{%4, %5, %6, %7},"
:
"{%8, %9},"
"=r"
(
psum
.
data
[
4
]),
"=r"
(
psum
.
data
[
5
]),
"=r"
(
psum
.
data
[
6
]),
"=r"
(
psum
.
data
[
7
])
"{%10, %11, %12, %13};
\n
"
:
:
"=r"
(
psum
.
data
[
4
]),
"=r"
(
psum
.
data
[
5
]),
"=r"
(
psum
.
data
[
6
]),
"=r"
(
psum
.
data
[
7
])
"r"
(
act
.
x
),
"r"
(
act
.
y
),
"r"
(
act
.
z
),
"r"
(
act
.
w
),
:
"r"
(
act
.
x
),
"r"
(
wgt
.
z
),
"r"
(
wgt
.
w
),
"r"
(
act
.
y
),
// "r"(0), "r"(0), "r"(0), "r"(0)
"r"
(
act
.
z
),
"r"
(
psum
.
data
[
4
]),
"r"
(
psum
.
data
[
5
]),
"r"
(
psum
.
data
[
6
]),
"r"
(
psum
.
data
[
7
])
"r"
(
act
.
w
),
);
"r"
(
wgt
.
z
),
"r"
(
wgt
.
w
),
// "r"(0), "r"(0), "r"(0), "r"(0)
"r"
(
psum
.
data
[
4
]),
"r"
(
psum
.
data
[
5
]),
"r"
(
psum
.
data
[
6
]),
"r"
(
psum
.
data
[
7
]));
return
psum
;
return
psum
;
}
}
__device__
__forceinline__
__device__
__forceinline__
static
void
compute
(
act_warp
A
,
wgt_warp
W
,
psum_warp
&
psum
)
{
static
void
compute
(
act_warp
A
,
wgt_warp
W
,
psum_warp
&
psum
)
{
const
int
laneId
=
threadIdx
.
x
%
WARP_SIZE
;
const
int
laneId
=
threadIdx
.
x
%
WARP_SIZE
;
const
int
warpId
=
threadIdx
.
x
/
WARP_SIZE
;
const
int
warpId
=
threadIdx
.
x
/
WARP_SIZE
;
#pragma unroll
#pragma unroll
for
(
int
j
=
0
;
j
<
WARP_N_TILES
;
j
++
)
{
for
(
int
j
=
0
;
j
<
WARP_N_TILES
;
j
++
)
{
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_M_TILES
;
i
++
)
{
for
(
int
i
=
0
;
i
<
WARP_M_TILES
;
i
++
)
{
psum
[
i
*
WARP_N_TILES
+
j
]
=
mma
(
A
[
i
],
W
[
j
],
psum
[
i
*
WARP_N_TILES
+
j
]);
psum
[
i
*
WARP_N_TILES
+
j
]
=
mma
(
A
[
i
],
W
[
j
],
psum
[
i
*
WARP_N_TILES
+
j
]);
}
}
...
@@ -62,11 +66,12 @@ public:
...
@@ -62,11 +66,12 @@ public:
* oscales is per-warp (in shared memory)
* oscales is per-warp (in shared memory)
* output is per-thread (in regs)
* output is per-thread (in regs)
* shmem must be at least INSN_M * (INSN_K * sizeof(element) + 16) (16 * 32 = 512 Bytes)
* shmem must be at least INSN_M * (INSN_K * sizeof(element) + 16) (16 * 32 = 512 Bytes)
* default to quantize activation, if quantize weight, input should be column-majored and output should be transposed ({x, y, z, w} = {x, z, y, w})
* default to quantize activation, if quantize weight, input should be column-majored and output should be
* transposed ({x, y, z, w} = {x, z, y, w})
*/
*/
template
<
bool
input_shmem
=
false
>
template
<
bool
input_shmem
=
false
>
__device__
__forceinline__
__device__
__forceinline__
static
void
static
void
quantize_w8a8_warp
(
const
half_t
*
input
,
const
half_t
*
oscales
,
int
stride
,
packed_act_t
&
output
,
void
*
shmem
)
{
quantize_w8a8_warp
(
const
half_t
*
input
,
const
half_t
*
oscales
,
int
stride
,
packed_act_t
&
output
,
void
*
shmem
)
{
const
int
laneId
=
threadIdx
.
x
%
WARP_SIZE
;
const
int
laneId
=
threadIdx
.
x
%
WARP_SIZE
;
constexpr
int
QUANTIZE_BITWIDTH
=
8
;
constexpr
int
QUANTIZE_BITWIDTH
=
8
;
...
@@ -75,28 +80,29 @@ public:
...
@@ -75,28 +80,29 @@ public:
// 1 lane = 1 pack
// 1 lane = 1 pack
// 1 warp = 32 lanes = 32 packs = 1 packwarp
// 1 warp = 32 lanes = 32 packs = 1 packwarp
// a pack is {a0, ..., a7} in figure https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=ex2#mma-16864-a
// a pack is {a0, ..., a7} in figure
// PACK_SIZE * 4 = INSN_K / 2
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=ex2#mma-16864-a PACK_SIZE * 4 =
constexpr
int
PACK_SIZE
=
INSN_K
/
8
;
// = 4 for 8bit
// INSN_K / 2
constexpr
int
NUM_PACKS_PER_ROW
=
INSN_K
/
PACK_SIZE
;
constexpr
int
PACK_SIZE
=
INSN_K
/
8
;
// = 4 for 8bit
constexpr
int
NUM_PACKS_PER_ROW
=
INSN_K
/
PACK_SIZE
;
constexpr
int
NUM_ROWS_PER_PACKWARP
=
PACK_SIZE
*
WARP_SIZE
/
INSN_K
;
constexpr
int
NUM_ROWS_PER_PACKWARP
=
PACK_SIZE
*
WARP_SIZE
/
INSN_K
;
constexpr
int
NUM_PACKWARPS
=
INSN_M
/
NUM_ROWS_PER_PACKWARP
;
constexpr
int
NUM_PACKWARPS
=
INSN_M
/
NUM_ROWS_PER_PACKWARP
;
using
packed_input
=
std
::
array
<
half_t
,
PACK_SIZE
>
;
using
packed_input
=
std
::
array
<
half_t
,
PACK_SIZE
>
;
packed_input
packs
[
NUM_PACKWARPS
];
packed_input
packs
[
NUM_PACKWARPS
];
// load
// load
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_PACKWARPS
;
i
++
)
{
for
(
int
i
=
0
;
i
<
NUM_PACKWARPS
;
i
++
)
{
int
rowId
=
i
*
NUM_ROWS_PER_PACKWARP
+
laneId
/
NUM_PACKS_PER_ROW
;
int
rowId
=
i
*
NUM_ROWS_PER_PACKWARP
+
laneId
/
NUM_PACKS_PER_ROW
;
int
colId
=
laneId
%
NUM_PACKS_PER_ROW
*
PACK_SIZE
;
int
colId
=
laneId
%
NUM_PACKS_PER_ROW
*
PACK_SIZE
;
packs
[
i
]
=
load
<
input_shmem
>
(
reinterpret_cast
<
const
packed_input
*>
(
input
+
rowId
*
stride
+
colId
));
packs
[
i
]
=
load
<
input_shmem
>
(
reinterpret_cast
<
const
packed_input
*>
(
input
+
rowId
*
stride
+
colId
));
}
}
// quantize
// quantize
using
matrix_t
=
uint32_t
[
INSN_M
][
NUM_PACKS_PER_ROW
];
using
matrix_t
=
uint32_t
[
INSN_M
][
NUM_PACKS_PER_ROW
];
matrix_t
&
mat
=
*
reinterpret_cast
<
matrix_t
*>
(
shmem
);
matrix_t
&
mat
=
*
reinterpret_cast
<
matrix_t
*>
(
shmem
);
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_PACKWARPS
;
i
++
)
{
for
(
int
i
=
0
;
i
<
NUM_PACKWARPS
;
i
++
)
{
const
int
row
=
i
*
NUM_ROWS_PER_PACKWARP
+
laneId
/
NUM_PACKS_PER_ROW
;
const
int
row
=
i
*
NUM_ROWS_PER_PACKWARP
+
laneId
/
NUM_PACKS_PER_ROW
;
const
int
col
=
laneId
%
NUM_PACKS_PER_ROW
;
const
int
col
=
laneId
%
NUM_PACKS_PER_ROW
;
...
@@ -104,7 +110,7 @@ public:
...
@@ -104,7 +110,7 @@ public:
float
rscale
=
cuda_frcp
(
float
(
oscales
[
row
]));
float
rscale
=
cuda_frcp
(
float
(
oscales
[
row
]));
uint32_t
qpack
=
0
;
uint32_t
qpack
=
0
;
#pragma unroll
#pragma unroll
for
(
int
j
=
0
;
j
<
PACK_SIZE
;
j
+=
2
)
{
for
(
int
j
=
0
;
j
<
PACK_SIZE
;
j
+=
2
)
{
// half2_t hval = __hmul2(half2_t(rscale, rscale), half2_t(packs[i][j], packs[i][j + 1]));
// half2_t hval = __hmul2(half2_t(rscale, rscale), half2_t(packs[i][j], packs[i][j + 1]));
float2
fval
=
half22float2
(
half2_t
(
packs
[
i
][
j
],
packs
[
i
][
j
+
1
]))
*
float2
(
rscale
,
rscale
);
float2
fval
=
half22float2
(
half2_t
(
packs
[
i
][
j
],
packs
[
i
][
j
+
1
]))
*
float2
(
rscale
,
rscale
);
...
@@ -113,7 +119,7 @@ public:
...
@@ -113,7 +119,7 @@ public:
mat
[
row
][
col
]
=
qpack
;
mat
[
row
][
col
]
=
qpack
;
}
}
__syncwarp
();
__syncwarp
();
// convert to imma format
// convert to imma format
int
row
=
laneId
%
16
;
int
row
=
laneId
%
16
;
int
col
=
laneId
/
16
*
4
;
int
col
=
laneId
/
16
*
4
;
...
@@ -126,20 +132,20 @@ public:
...
@@ -126,20 +132,20 @@ public:
* each warp finds absmax from a row
* each warp finds absmax from a row
*/
*/
template
<
bool
fuse_glu
=
false
>
template
<
bool
fuse_glu
=
false
>
__device__
__forceinline__
__device__
__forceinline__
static
half_t
static
half_t
findmax_warp
(
const
half_t
*
input
,
half_t
*
output_shmem
,
int
K
,
bool
alwaysfalse
)
{
findmax_warp
(
const
half_t
*
input
,
half_t
*
output_shmem
,
int
K
,
bool
alwaysfalse
)
{
const
int
laneId
=
threadIdx
.
x
%
WARP_SIZE
;
const
int
laneId
=
threadIdx
.
x
%
WARP_SIZE
;
using
packed_input
=
std
::
array
<
half2_t
,
4
>
;
using
packed_input
=
std
::
array
<
half2_t
,
4
>
;
using
packed_gated_input
=
std
::
array
<
half_t
,
4
>
;
using
packed_gated_input
=
std
::
array
<
half_t
,
4
>
;
constexpr
int
PACK_SIZE
=
sizeof
(
packed_input
)
/
sizeof
(
half_t
);
constexpr
int
PACK_SIZE
=
sizeof
(
packed_input
)
/
sizeof
(
half_t
);
constexpr
int
NUM_STAGES
=
2
;
constexpr
int
NUM_STAGES
=
2
;
half2_t
maxvalue2
=
{
0
,
0
};
half2_t
maxvalue2
=
{
0
,
0
};
packed_input
pack
[
NUM_STAGES
];
packed_input
pack
[
NUM_STAGES
];
#pragma unroll
#pragma unroll
for
(
int
k
=
0
;
k
<
NUM_STAGES
-
1
;
k
++
)
{
for
(
int
k
=
0
;
k
<
NUM_STAGES
-
1
;
k
++
)
{
const
int
idx
=
k
*
PACK_SIZE
*
WARP_SIZE
+
laneId
*
PACK_SIZE
;
const
int
idx
=
k
*
PACK_SIZE
*
WARP_SIZE
+
laneId
*
PACK_SIZE
;
if
(
idx
<
K
)
{
if
(
idx
<
K
)
{
...
@@ -155,11 +161,11 @@ public:
...
@@ -155,11 +161,11 @@ public:
// TODO: store quantized data to shmem (instead of half)
// TODO: store quantized data to shmem (instead of half)
for
(
int
k1
=
0
;
k1
<
ceilDiv
(
K
,
PACK_SIZE
*
WARP_SIZE
);
k1
+=
NUM_STAGES
)
{
for
(
int
k1
=
0
;
k1
<
ceilDiv
(
K
,
PACK_SIZE
*
WARP_SIZE
);
k1
+=
NUM_STAGES
)
{
#pragma unroll
#pragma unroll
for
(
int
k2
=
0
;
k2
<
NUM_STAGES
;
k2
++
)
{
for
(
int
k2
=
0
;
k2
<
NUM_STAGES
;
k2
++
)
{
const
int
nextidx
=
(
k1
+
k2
+
NUM_STAGES
-
1
)
*
PACK_SIZE
*
WARP_SIZE
+
laneId
*
PACK_SIZE
;
const
int
nextidx
=
(
k1
+
k2
+
NUM_STAGES
-
1
)
*
PACK_SIZE
*
WARP_SIZE
+
laneId
*
PACK_SIZE
;
const
int
nextk2
=
(
k2
+
NUM_STAGES
-
1
)
%
NUM_STAGES
;
const
int
nextk2
=
(
k2
+
NUM_STAGES
-
1
)
%
NUM_STAGES
;
if
(
nextidx
<
K
)
{
if
(
nextidx
<
K
)
{
pack
[
nextk2
]
=
load
(
reinterpret_cast
<
const
packed_input
*>
(
&
input
[
nextidx
]));
pack
[
nextk2
]
=
load
(
reinterpret_cast
<
const
packed_input
*>
(
&
input
[
nextidx
]));
...
@@ -172,11 +178,11 @@ public:
...
@@ -172,11 +178,11 @@ public:
if
constexpr
(
fuse_glu
)
{
if
constexpr
(
fuse_glu
)
{
packed_gated_input
gated
;
packed_gated_input
gated
;
#pragma unroll
#pragma unroll
for
(
int
j
=
0
;
j
<
p
.
size
();
j
++
)
{
for
(
int
j
=
0
;
j
<
p
.
size
();
j
++
)
{
gated
[
j
]
=
p
[
j
].
x
*
gelu_half
(
p
[
j
].
y
);
gated
[
j
]
=
p
[
j
].
x
*
gelu_half
(
p
[
j
].
y
);
p
[
j
].
x
=
gated
[
j
];
p
[
j
].
x
=
gated
[
j
];
p
[
j
].
y
=
0
;
p
[
j
].
y
=
0
;
}
}
int
idx
=
(
k1
+
k2
)
*
PACK_SIZE
/
2
*
WARP_SIZE
+
laneId
*
PACK_SIZE
/
2
;
int
idx
=
(
k1
+
k2
)
*
PACK_SIZE
/
2
*
WARP_SIZE
+
laneId
*
PACK_SIZE
/
2
;
...
@@ -185,7 +191,7 @@ public:
...
@@ -185,7 +191,7 @@ public:
}
}
}
}
#pragma unroll
#pragma unroll
for
(
int
j
=
0
;
j
<
p
.
size
();
j
++
)
{
for
(
int
j
=
0
;
j
<
p
.
size
();
j
++
)
{
maxvalue2
=
__hmax2
(
maxvalue2
,
__habs2
(
p
[
j
]));
maxvalue2
=
__hmax2
(
maxvalue2
,
__habs2
(
p
[
j
]));
}
}
...
@@ -194,7 +200,7 @@ public:
...
@@ -194,7 +200,7 @@ public:
// unused_var(dummy, alwaysfalse);
// unused_var(dummy, alwaysfalse);
#pragma unroll
#pragma unroll
for
(
int
mask
=
32
/
2
;
mask
>
0
;
mask
/=
2
)
{
for
(
int
mask
=
32
/
2
;
mask
>
0
;
mask
/=
2
)
{
maxvalue2
=
__hmax2
(
maxvalue2
,
__shfl_xor_sync
(
~
0
,
maxvalue2
,
mask
));
maxvalue2
=
__hmax2
(
maxvalue2
,
__shfl_xor_sync
(
~
0
,
maxvalue2
,
mask
));
}
}
...
@@ -223,8 +229,8 @@ public:
...
@@ -223,8 +229,8 @@ public:
return
INSN_M
*
K2
*
sizeof
(
half_t
);
return
INSN_M
*
K2
*
sizeof
(
half_t
);
}
}
__device__
__device__
void
void
operator
()(
const
half_t
*
input
,
packed_act_t
*
output
,
packed_ascale_t
*
oscales
,
int
K
,
bool
alwaysfalse
)
{
operator
()(
const
half_t
*
input
,
packed_act_t
*
output
,
packed_ascale_t
*
oscales
,
int
K
,
bool
alwaysfalse
)
{
// for quantize kernel
// for quantize kernel
const
int
laneId
=
threadIdx
.
x
%
WARP_SIZE
;
const
int
laneId
=
threadIdx
.
x
%
WARP_SIZE
;
const
int
warpId
=
threadIdx
.
x
/
WARP_SIZE
;
const
int
warpId
=
threadIdx
.
x
/
WARP_SIZE
;
...
@@ -232,10 +238,9 @@ public:
...
@@ -232,10 +238,9 @@ public:
const
int
numWarps
=
blockDim
.
x
/
WARP_SIZE
;
const
int
numWarps
=
blockDim
.
x
/
WARP_SIZE
;
// for GEMM kernel
// for GEMM kernel
const
int
bm
=
blockIdx
.
x
/
(
BLOCK_M
/
WARP_M
);
const
int
bm
=
blockIdx
.
x
/
(
BLOCK_M
/
WARP_M
);
const
int
gemmWarpId
=
blockIdx
.
x
%
(
BLOCK_M
/
WARP_M
);
const
int
gemmWarpId
=
blockIdx
.
x
%
(
BLOCK_M
/
WARP_M
);
__shared__
alignas
(
128
)
half_t
oscale_shmem
[
WARP_M
];
__shared__
alignas
(
128
)
half_t
oscale_shmem
[
WARP_M
];
// __shared__ alignas(128) half_t maxv_shmem[WARP_M];
// __shared__ alignas(128) half_t maxv_shmem[WARP_M];
__shared__
alignas
(
128
)
uint8_t
tmp_shmem
[
NUM_WARPS
][
512
];
__shared__
alignas
(
128
)
uint8_t
tmp_shmem
[
NUM_WARPS
][
512
];
...
@@ -249,7 +254,7 @@ public:
...
@@ -249,7 +254,7 @@ public:
for
(
int
tileM
=
0
;
tileM
<
WARP_M_TILES
;
tileM
++
)
{
for
(
int
tileM
=
0
;
tileM
<
WARP_M_TILES
;
tileM
++
)
{
for
(
int
i
=
warpId
;
i
<
INSN_M
;
i
+=
numWarps
)
{
for
(
int
i
=
warpId
;
i
<
INSN_M
;
i
+=
numWarps
)
{
const
int
rowLocal
=
tileM
*
INSN_M
+
i
;
const
int
rowLocal
=
tileM
*
INSN_M
+
i
;
const
int
rowGlobal
=
blockIdx
.
x
*
WARP_M
+
rowLocal
;
const
int
rowGlobal
=
blockIdx
.
x
*
WARP_M
+
rowLocal
;
half_t
maxv
=
findmax_warp
<
fuse_glu
>
(
input
+
rowGlobal
*
K
,
shmem
+
i
*
K2
,
K
,
alwaysfalse
);
half_t
maxv
=
findmax_warp
<
fuse_glu
>
(
input
+
rowGlobal
*
K
,
shmem
+
i
*
K2
,
K
,
alwaysfalse
);
...
@@ -260,76 +265,66 @@ public:
...
@@ -260,76 +265,66 @@ public:
__syncthreads
();
__syncthreads
();
for
(
int
bk
=
warpId
;
bk
<
K2
/
WARP_K
;
bk
+=
numWarps
)
{
for
(
int
bk
=
warpId
;
bk
<
K2
/
WARP_K
;
bk
+=
numWarps
)
{
const
int
rowLocal
=
tileM
*
INSN_M
;
const
int
rowLocal
=
tileM
*
INSN_M
;
const
int
rowGlobal
=
blockIdx
.
x
*
WARP_M
+
rowLocal
;
const
int
rowGlobal
=
blockIdx
.
x
*
WARP_M
+
rowLocal
;
const
int
col
=
bk
*
WARP_K
;
const
int
col
=
bk
*
WARP_K
;
packed_act_t
tmpout
;
packed_act_t
tmpout
;
if
constexpr
(
fuse_glu
)
{
if
constexpr
(
fuse_glu
)
{
quantize_w8a8_warp
<
true
>
(
quantize_w8a8_warp
<
true
>
(
shmem
+
col
,
oscale_shmem
+
rowLocal
,
K2
,
tmpout
,
&
tmp_shmem
[
warpId
]);
shmem
+
col
,
oscale_shmem
+
rowLocal
,
K2
,
tmpout
,
&
tmp_shmem
[
warpId
]
);
}
else
{
}
else
{
quantize_w8a8_warp
<
false
>
(
quantize_w8a8_warp
<
false
>
(
input
+
rowGlobal
*
K
+
col
,
input
+
rowGlobal
*
K
+
col
,
oscale_shmem
+
rowLocal
,
K
,
tmpout
,
&
tmp_shmem
[
warpId
]);
oscale_shmem
+
rowLocal
,
K
,
tmpout
,
&
tmp_shmem
[
warpId
]
);
}
}
store
(
&
output
[(((
bm
*
K2
/
WARP_K
+
bk
)
*
NUM_WARPS
+
gemmWarpId
)
*
WARP_M_TILES
+
tileM
)
*
WARP_SIZE
+
laneId
],
tmpout
);
store
(
&
output
[(((
bm
*
K2
/
WARP_K
+
bk
)
*
NUM_WARPS
+
gemmWarpId
)
*
WARP_M_TILES
+
tileM
)
*
WARP_SIZE
+
laneId
],
tmpout
);
}
}
__syncthreads
();
__syncthreads
();
}
}
// [M / BLOCK_M, 1, NUM_WARPS, ASCALES_NUM_PACKS, ASCALES_VALID_LANES] of packed_ascale_t
// [M / BLOCK_M, 1, NUM_WARPS, ASCALES_NUM_PACKS, ASCALES_VALID_LANES] of packed_ascale_t
pack_ascales
(
oscale_shmem
,
&
oscales
[(
bm
*
NUM_WARPS
+
gemmWarpId
)
*
ASCALES_NUM_PACKS
*
ASCALES_VALID_LANES
]);
pack_ascales
(
oscale_shmem
,
&
oscales
[(
bm
*
NUM_WARPS
+
gemmWarpId
)
*
ASCALES_NUM_PACKS
*
ASCALES_VALID_LANES
]);
}
}
};
};
__device__
__forceinline__
static
gated_fpsum_warp
apply_glu
(
fpsum_warp
fpsum
)
{
__device__
__forceinline__
static
gated_fpsum_warp
apply_glu
(
fpsum_warp
fpsum
)
{
gated_fpsum_warp
result
;
gated_fpsum_warp
result
;
for
(
int
i
=
0
;
i
<
WARP_M_TILES
;
i
++
)
{
for
(
int
i
=
0
;
i
<
WARP_M_TILES
;
i
++
)
{
for
(
int
j
=
0
;
j
<
WARP_N_TILES
;
j
++
)
{
for
(
int
j
=
0
;
j
<
WARP_N_TILES
;
j
++
)
{
for
(
int
k
=
0
;
k
<
4
;
k
++
)
{
for
(
int
k
=
0
;
k
<
4
;
k
++
)
{
half_t
&
dst
=
result
[
i
*
WARP_N_TILES
+
j
].
data
[
k
];
half_t
&
dst
=
result
[
i
*
WARP_N_TILES
+
j
].
data
[
k
];
half2_t
src
=
fpsum
[
i
*
WARP_N_TILES
+
j
].
data
[
k
];
half2_t
src
=
fpsum
[
i
*
WARP_N_TILES
+
j
].
data
[
k
];
dst
=
src
.
x
*
gelu_half
(
src
.
y
);
dst
=
src
.
x
*
gelu_half
(
src
.
y
);
}
}
}
}
}
}
return
result
;
return
result
;
}
}
static
constexpr
int
unpack_gated_fpsum_shmem_size
=
INSN_M
*
(
WARP_N
/
2
+
8
)
*
sizeof
(
half_t
);
static
constexpr
int
unpack_gated_fpsum_shmem_size
=
INSN_M
*
(
WARP_N
/
2
+
8
)
*
sizeof
(
half_t
);
__device__
__forceinline__
__device__
__forceinline__
static
void
static
void
unpack_gated_fpsum
(
gated_fpsum_warp
fpsum
,
half_t
*
output
,
int
stride
,
void
*
shmem
)
{
unpack_gated_fpsum
(
gated_fpsum_warp
fpsum
,
half_t
*
output
,
int
stride
,
void
*
shmem
)
{
const
int
laneId
=
threadIdx
.
x
%
WARP_SIZE
;
const
int
laneId
=
threadIdx
.
x
%
WARP_SIZE
;
constexpr
int
PACK_SIZE
=
WARP_N
/
2
/
WARP_SIZE
;
constexpr
int
PACK_SIZE
=
WARP_N
/
2
/
WARP_SIZE
;
using
pack_t
=
std
::
array
<
half_t
,
PACK_SIZE
>
;
using
pack_t
=
std
::
array
<
half_t
,
PACK_SIZE
>
;
// +8 to prevent bank conflicts
// +8 to prevent bank conflicts
using
matrix_t
=
half_t
[
INSN_M
][
WARP_N
/
2
+
8
];
using
matrix_t
=
half_t
[
INSN_M
][
WARP_N
/
2
+
8
];
matrix_t
&
mat
=
*
reinterpret_cast
<
matrix_t
*>
(
shmem
);
matrix_t
&
mat
=
*
reinterpret_cast
<
matrix_t
*>
(
shmem
);
for
(
int
i
=
0
;
i
<
WARP_M_TILES
;
i
++
)
{
for
(
int
i
=
0
;
i
<
WARP_M_TILES
;
i
++
)
{
for
(
int
j
=
0
;
j
<
WARP_N_TILES
;
j
++
)
{
for
(
int
j
=
0
;
j
<
WARP_N_TILES
;
j
++
)
{
packed_gated_fpsum_t
&
fsum
=
fpsum
[
i
*
WARP_N_TILES
+
j
];
packed_gated_fpsum_t
&
fsum
=
fpsum
[
i
*
WARP_N_TILES
+
j
];
int
row
=
laneId
/
4
;
int
row
=
laneId
/
4
;
int
col
=
laneId
%
4
+
j
*
INSN_N
/
2
;
int
col
=
laneId
%
4
+
j
*
INSN_N
/
2
;
*
reinterpret_cast
<
half_t
*>
(
&
mat
[
row
][
col
+
0
])
=
fsum
.
data
[
0
];
*
reinterpret_cast
<
half_t
*>
(
&
mat
[
row
][
col
+
0
])
=
fsum
.
data
[
0
];
*
reinterpret_cast
<
half_t
*>
(
&
mat
[
row
][
col
+
4
])
=
fsum
.
data
[
2
];
*
reinterpret_cast
<
half_t
*>
(
&
mat
[
row
][
col
+
4
])
=
fsum
.
data
[
2
];
*
reinterpret_cast
<
half_t
*>
(
&
mat
[
row
+
8
][
col
+
4
])
=
fsum
.
data
[
1
];
*
reinterpret_cast
<
half_t
*>
(
&
mat
[
row
+
8
][
col
+
4
])
=
fsum
.
data
[
1
];
*
reinterpret_cast
<
half_t
*>
(
&
mat
[
row
+
8
][
col
+
4
])
=
fsum
.
data
[
3
];
*
reinterpret_cast
<
half_t
*>
(
&
mat
[
row
+
8
][
col
+
4
])
=
fsum
.
data
[
3
];
}
}
...
@@ -345,28 +340,27 @@ public:
...
@@ -345,28 +340,27 @@ public:
// out: [M, N] <=> [..., NUM_WARPS, WARP_M, N] of half
// out: [M, N] <=> [..., NUM_WARPS, WARP_M, N] of half
template
<
typename
Epilogue
>
template
<
typename
Epilogue
>
__device__
__forceinline__
__device__
__forceinline__
static
void
gemm_w8a8_block
(
const
BlockInfo
binfo
,
static
void
gemm_w8a8_block
(
const
packed_act_t
*
act
,
const
BlockInfo
binfo
,
const
packed_wgt_t
*
wgt
,
const
packed_act_t
*
act
,
const
packed_ascale_t
*
ascales
,
const
packed_wgt_t
*
wgt
,
const
packed_wscale_t
*
wscales
,
const
packed_ascale_t
*
ascales
,
// half_t *out,
const
packed_wscale_t
*
wscales
,
int
M
,
// half_t *out,
int
N
,
int
M
,
int
N
,
int
K
,
int
K
,
Epilogue
::
Arguments
epilogeParams
,
Epilogue
::
Arguments
epilogeParams
,
bool
alwaysfalse
)
bool
alwaysfalse
)
{
{
constexpr
int
NUM_STAGES
=
2
;
constexpr
int
NUM_STAGES
=
2
;
const
int
laneId
=
threadIdx
.
x
%
WARP_SIZE
;
const
int
laneId
=
threadIdx
.
x
%
WARP_SIZE
;
const
int
warpId
=
threadIdx
.
x
/
WARP_SIZE
;
const
int
warpId
=
threadIdx
.
x
/
WARP_SIZE
;
act_warp
A
[
NUM_STAGES
];
// 8
act_warp
A
[
NUM_STAGES
];
// 8
wgt_warp
W
[
NUM_STAGES
];
// 32
wgt_warp
W
[
NUM_STAGES
];
// 32
ascale_warp
ascale
;
// 1
ascale_warp
ascale
;
// 1
wscale_warp
wscale
;
// 2
wscale_warp
wscale
;
// 2
psum_warp
psum
;
// 128
psum_warp
psum
;
// 128
for
(
auto
&
pack
:
psum
)
{
for
(
auto
&
pack
:
psum
)
{
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
...
@@ -377,7 +371,7 @@ public:
...
@@ -377,7 +371,7 @@ public:
// load_wscale<true>(wscales, wscale[0], true);
// load_wscale<true>(wscales, wscale[0], true);
// load_wscale<false>(wscales, wscale[1], true);
// load_wscale<false>(wscales, wscale[1], true);
// load_wscale<false>(wscales, wscale[2], true);
// load_wscale<false>(wscales, wscale[2], true);
load_ascale
(
ascales
,
0
,
M
,
ascale
,
true
);
load_ascale
(
ascales
,
0
,
M
,
ascale
,
true
);
load_wscale
(
wscales
,
0
,
N
,
wscale
,
true
);
load_wscale
(
wscales
,
0
,
N
,
wscale
,
true
);
...
@@ -385,14 +379,14 @@ public:
...
@@ -385,14 +379,14 @@ public:
load_act
(
act
,
k
,
K
,
A
[
k
],
true
);
load_act
(
act
,
k
,
K
,
A
[
k
],
true
);
load_wgt
(
wgt
,
k
,
K
,
W
[
k
],
true
);
load_wgt
(
wgt
,
k
,
K
,
W
[
k
],
true
);
}
}
int
dummy
=
0
;
int
dummy
=
0
;
for
(
int
k1
=
0
;
k1
<
K
/
WARP_K
;
k1
+=
NUM_STAGES
)
{
for
(
int
k1
=
0
;
k1
<
K
/
WARP_K
;
k1
+=
NUM_STAGES
)
{
#pragma unroll
#pragma unroll
for
(
int
k2
=
0
;
k2
<
NUM_STAGES
;
k2
++
)
{
for
(
int
k2
=
0
;
k2
<
NUM_STAGES
;
k2
++
)
{
int
nextk
=
k1
+
k2
+
NUM_STAGES
-
1
;
int
nextk
=
k1
+
k2
+
NUM_STAGES
-
1
;
int
idx
=
(
k2
+
NUM_STAGES
-
1
)
%
NUM_STAGES
;
int
idx
=
(
k2
+
NUM_STAGES
-
1
)
%
NUM_STAGES
;
bool
pred
=
nextk
<
K
/
WARP_K
;
bool
pred
=
nextk
<
K
/
WARP_K
;
load_act
(
act
,
nextk
,
K
,
A
[
idx
],
pred
);
load_act
(
act
,
nextk
,
K
,
A
[
idx
],
pred
);
load_wgt
(
wgt
,
nextk
,
K
,
W
[
idx
],
pred
);
load_wgt
(
wgt
,
nextk
,
K
,
W
[
idx
],
pred
);
...
@@ -421,17 +415,15 @@ public:
...
@@ -421,17 +415,15 @@ public:
f32psum_warp
f32psum
;
f32psum_warp
f32psum
;
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
f32psum
.
size
();
i
++
)
{
for
(
int
i
=
0
;
i
<
f32psum
.
size
();
i
++
)
{
#pragma unroll
#pragma unroll
for
(
int
j
=
0
;
j
<
8
;
j
++
)
{
for
(
int
j
=
0
;
j
<
8
;
j
++
)
{
f32psum
[
i
].
data
[
j
]
=
0
;
f32psum
[
i
].
data
[
j
]
=
0
;
}
}
}
}
apply_scales
([
&
](
int
i
,
int
j
)
{
apply_scales
([
&
](
int
i
,
int
j
)
{
return
psum
[
i
*
WARP_N_TILES
+
j
];
},
ascale
,
wscale
,
f32psum
);
return
psum
[
i
*
WARP_N_TILES
+
j
];
},
ascale
,
wscale
,
f32psum
);
fpsum_warp
fpsum
=
packed_fp32_to_fp16
(
f32psum
);
fpsum_warp
fpsum
=
packed_fp32_to_fp16
(
f32psum
);
...
@@ -443,27 +435,24 @@ public:
...
@@ -443,27 +435,24 @@ public:
Epilogue
()(
binfo
,
fpsum
,
M
,
N
,
K
,
epilogeParams
);
Epilogue
()(
binfo
,
fpsum
,
M
,
N
,
K
,
epilogeParams
);
}
}
// out : [M / BLOCK_M, BLOCK_M, N / BLOCK_N, BLOCK_N]
// out : [M / BLOCK_M, BLOCK_M, N / BLOCK_N, BLOCK_N]
template
<
typename
Epilogue
>
template
<
typename
Epilogue
>
struct
gemm_w8a8_kernel
{
struct
gemm_w8a8_kernel
{
static
constexpr
int
MIN_ARCH
=
std
::
is_same_v
<
half_t
,
__nv_bfloat16
>
?
800
:
750
;
static
constexpr
int
MIN_ARCH
=
std
::
is_same_v
<
half_t
,
__nv_bfloat16
>
?
800
:
750
;
__device__
__device__
void
operator
()(
const
packed_act_t
*
act
,
void
operator
()(
const
packed_wgt_t
*
wgt
,
const
packed_act_t
*
act
,
const
packed_ascale_t
*
ascales
,
const
packed_wgt_t
*
wgt
,
const
packed_wscale_t
*
wscales
,
const
packed_ascale_t
*
ascales
,
// half_t *out,
const
packed_wscale_t
*
wscales
,
int
M
,
// half_t *out,
int
N
,
int
M
,
int
N
,
int
K
,
int
K
,
Epilogue
::
Arguments
epilogueArgs
,
Epilogue
::
Arguments
epilogueArgs
,
bool
swapBlockXY
,
bool
swapBlockXY
,
bool
alwaysfalse
)
bool
alwaysfalse
)
{
{
BlockInfo
binfo
=
{
BlockInfo
binfo
=
{
.
bm
=
(
int
)
blockIdx
.
x
,
.
bm
=
(
int
)
blockIdx
.
x
,
.
bn
=
(
int
)
blockIdx
.
y
,
.
bn
=
(
int
)
blockIdx
.
y
,
.
numBlocksM
=
(
int
)
gridDim
.
x
,
.
numBlocksM
=
(
int
)
gridDim
.
x
,
.
numBlocksN
=
(
int
)
gridDim
.
y
,
.
numBlocksN
=
(
int
)
gridDim
.
y
,
};
};
...
@@ -476,25 +465,25 @@ public:
...
@@ -476,25 +465,25 @@ public:
const
int
bm
=
binfo
.
bm
;
const
int
bm
=
binfo
.
bm
;
const
int
bn
=
binfo
.
bn
;
const
int
bn
=
binfo
.
bn
;
gemm_w8a8_block
<
Epilogue
>
(
gemm_w8a8_block
<
Epilogue
>
(
binfo
,
binfo
,
act
+
bm
*
(
K
/
WARP_K
)
*
NUM_WARPS
*
WARP_M_TILES
*
WARP_SIZE
,
act
+
bm
*
(
K
/
WARP_K
)
*
NUM_WARPS
*
WARP_M_TILES
*
WARP_SIZE
,
wgt
+
bn
*
(
K
/
WARP_K
)
*
WARP_N_TILES
*
WARP_SIZE
,
wgt
+
bn
*
(
K
/
WARP_K
)
*
WARP_N_TILES
*
WARP_SIZE
,
ascales
+
bm
*
(
1
)
*
NUM_WARPS
*
ASCALES_NUM_PACKS
*
ascales
+
bm
*
(
1
)
*
NUM_WARPS
*
ASCALES_NUM_PACKS
*
ASCALES_VALID_LANES
,
// only 1 group in W8A8
ASCALES_VALID_LANES
,
// only 1 group in W8A8
wscales
+
bn
*
(
1
)
*
WSCALES_NUM_PACKS
*
WSCALES_VALID_LANES
,
wscales
+
bn
*
(
1
)
*
WSCALES_NUM_PACKS
*
WSCALES_VALID_LANES
,
// #if 1
// #if 1
// out + (bm * BLOCK_M * N) + bn * BLOCK_N,
// out + (bm * BLOCK_M * N) + bn * BLOCK_N,
// #else
// #else
// out + (bm * BLOCK_M * N / 2) + bn * BLOCK_N / 2,
// out + (bm * BLOCK_M * N / 2) + bn * BLOCK_N / 2,
// #endif
// #endif
M
,
N
,
K
,
M
,
epilogueArgs
,
N
,
alwaysfalse
K
,
);
epilogueArgs
,
alwaysfalse
);
}
}
};
};
#if 0
#if 0
struct EpilogueGLU {
struct EpilogueGLU {
struct Arguments { size_t unused; };
struct Arguments { size_t unused; };
...
@@ -510,9 +499,6 @@ public:
...
@@ -510,9 +499,6 @@ public:
}
}
};
};
#endif
#endif
};
};
};
// namespace nunchaku::kernels
};
// namespace nunchaku::kernels
\ No newline at end of file
src/kernels/zgemm/lora.cuh
View file @
37a27712
...
@@ -2,7 +2,6 @@
...
@@ -2,7 +2,6 @@
#include "gemm_base.cuh"
#include "gemm_base.cuh"
namespace
nunchaku
::
kernels
{
namespace
nunchaku
::
kernels
{
template
<
typename
Config
>
template
<
typename
Config
>
...
@@ -21,7 +20,7 @@ public:
...
@@ -21,7 +20,7 @@ public:
public:
public:
static
constexpr
int
MAX_RANK
=
1024
;
static
constexpr
int
MAX_RANK
=
1024
;
static
constexpr
int
WARP_R
=
16
;
static
constexpr
int
WARP_R
=
16
;
// static constexpr int LORA_RANK = rank;
// static constexpr int LORA_RANK = rank;
static
constexpr
int
LORA_M_TILES
=
WARP_M
/
16
;
static
constexpr
int
LORA_M_TILES
=
WARP_M
/
16
;
...
@@ -30,57 +29,57 @@ public:
...
@@ -30,57 +29,57 @@ public:
static_assert
(
LORA_M_TILES
==
WARP_M_TILES
);
static_assert
(
LORA_M_TILES
==
WARP_M_TILES
);
static_assert
(
LORA_N_TILES
==
WARP_N_TILES
);
static_assert
(
LORA_N_TILES
==
WARP_N_TILES
);
// lora_down: [WARP_M, WARP_N] x [WARP_N, R] (row-wise) = [WARP_M, R]
// lora_down: [WARP_M, WARP_N] x [WARP_N, R] (row-wise) = [WARP_M, R]
// lora up: [WARP_M, R] x [WARP_N, R] (col-wise) = [WARP_M, WARP_N]
// lora up: [WARP_M, R] x [WARP_N, R] (col-wise) = [WARP_M, WARP_N]
// we use fp32 for lora activation since there's no bf16 reduction in sm_89 :(
// we use fp32 for lora activation since there's no bf16 reduction in sm_89 :(
using
lora_act_warp
=
std
::
array
<
packed_f32psum_t
,
LORA_M_TILES
*
LORA_R_TILES
>
;
using
lora_act_warp
=
std
::
array
<
packed_f32psum_t
,
LORA_M_TILES
*
LORA_R_TILES
>
;
using
lora_act16_warp
=
std
::
array
<
packed_fpsum_t
,
LORA_M_TILES
*
LORA_R_TILES
>
;
using
lora_act16_warp
=
std
::
array
<
packed_fpsum_t
,
LORA_M_TILES
*
LORA_R_TILES
>
;
using
lora_wgt_warp
=
std
::
array
<
packed_fpsum_t
,
LORA_N_TILES
*
LORA_R_TILES
>
;
using
lora_wgt_warp
=
std
::
array
<
packed_fpsum_t
,
LORA_N_TILES
*
LORA_R_TILES
>
;
using
scale_t
=
std
::
array
<
float
,
MAX_RANK
/
16
>
;
using
scale_t
=
std
::
array
<
float
,
MAX_RANK
/
16
>
;
// lora_wgt: [N / 16, rank / WARP_R, LORA_R_TILES, WARP_SIZE] of packed_fpsum_t
// lora_wgt: [N / 16, rank / WARP_R, LORA_R_TILES, WARP_SIZE] of packed_fpsum_t
// [N / 16, rank / 16, WARP_SIZE]
// [N / 16, rank / 16, WARP_SIZE]
__device__
__forceinline__
__device__
__forceinline__
static
void
static
void
load_lora_wgt
(
const
packed_fpsum_t
*
ptr
,
int
rtile
,
int
rank
,
lora_wgt_warp
&
result
,
bool
pred
)
{
load_lora_wgt
(
const
packed_fpsum_t
*
ptr
,
int
rtile
,
int
rank
,
lora_wgt_warp
&
result
,
bool
pred
)
{
const
int
laneId
=
threadIdx
.
x
%
WARP_SIZE
;
const
int
laneId
=
threadIdx
.
x
%
WARP_SIZE
;
const
packed_fpsum_t
*
ptr_lane
=
&
ptr
[
rtile
*
LORA_R_TILES
*
WARP_SIZE
+
laneId
];
const
packed_fpsum_t
*
ptr_lane
=
&
ptr
[
rtile
*
LORA_R_TILES
*
WARP_SIZE
+
laneId
];
const
int
stride_ntile
=
rank
/
16
*
WARP_SIZE
;
const
int
stride_ntile
=
rank
/
16
*
WARP_SIZE
;
unrolled_loop
<
LORA_N_TILES
>
([
&
]
<
int
n
>
()
{
unrolled_loop
<
LORA_N_TILES
>
([
&
]
<
int
n
>
()
{
unrolled_loop
<
LORA_R_TILES
>
([
&
]
<
int
r
>
()
{
unrolled_loop
<
LORA_R_TILES
>
([
&
]
<
int
r
>
()
{
constexpr
int
roffset
=
r
*
WARP_SIZE
;
constexpr
int
roffset
=
r
*
WARP_SIZE
;
const
int
noffset
=
n
*
stride_ntile
;
const
int
noffset
=
n
*
stride_ntile
;
result
[
n
*
LORA_R_TILES
+
r
]
=
load_pred
(
ptr_lane
+
noffset
+
roffset
,
pred
);
result
[
n
*
LORA_R_TILES
+
r
]
=
load_pred
(
ptr_lane
+
noffset
+
roffset
,
pred
);
});
});
});
});
}
}
// lora_act: [M / BLOCK_M, rank / WARP_R, NUM_WARPS, LORA_M_TILES, LORA_R_TILES, 8, WARP_SIZE] of float
// lora_act: [M / BLOCK_M, rank / WARP_R, NUM_WARPS, LORA_M_TILES, LORA_R_TILES, 8, WARP_SIZE] of float
__device__
__forceinline__
__device__
__forceinline__
static
void
static
void
load_lora_act
(
const
float
*
ptr
,
int
rtile
,
lora_act_warp
&
result
,
bool
pred
)
{
load_lora_act
(
const
float
*
ptr
,
int
rtile
,
lora_act_warp
&
result
,
bool
pred
)
{
const
int
laneId
=
threadIdx
.
x
%
WARP_SIZE
;
const
int
laneId
=
threadIdx
.
x
%
WARP_SIZE
;
const
int
warpId
=
threadIdx
.
x
/
WARP_SIZE
;
const
int
warpId
=
threadIdx
.
x
/
WARP_SIZE
;
const
float
*
ptrlane
=
&
ptr
[(
rtile
*
NUM_WARPS
+
warpId
)
*
LORA_M_TILES
*
LORA_R_TILES
*
8
*
WARP_SIZE
+
laneId
];
const
float
*
ptrlane
=
&
ptr
[(
rtile
*
NUM_WARPS
+
warpId
)
*
LORA_M_TILES
*
LORA_R_TILES
*
8
*
WARP_SIZE
+
laneId
];
unrolled_loop
<
LORA_M_TILES
>
([
&
]
<
int
m
>
()
{
unrolled_loop
<
LORA_M_TILES
>
([
&
]
<
int
m
>
()
{
unrolled_loop
<
LORA_R_TILES
>
([
&
]
<
int
r
>
{
unrolled_loop
<
LORA_R_TILES
>
([
&
]
<
int
r
>
{
constexpr
int
i
=
m
*
LORA_R_TILES
+
r
;
constexpr
int
i
=
m
*
LORA_R_TILES
+
r
;
unrolled_loop
<
8
>
([
&
]
<
int
j
>
()
{
unrolled_loop
<
8
>
([
&
]
<
int
j
>
()
{
constexpr
int
offset
=
i
*
8
*
WARP_SIZE
+
j
*
WARP_SIZE
;
constexpr
int
offset
=
i
*
8
*
WARP_SIZE
+
j
*
WARP_SIZE
;
result
[
i
].
data
[
j
]
=
load_pred
(
ptrlane
+
offset
,
pred
);
// * scales[rtile * LORA_R_TILES + r];
result
[
i
].
data
[
j
]
=
load_pred
(
ptrlane
+
offset
,
pred
);
// * scales[rtile * LORA_R_TILES + r];
});
});
// CHECK_NAN(tmp, "load_lora_act.tmp");
// CHECK_NAN(tmp, "load_lora_act.tmp");
});
});
});
});
}
}
// no vector reduction in sm_89 :(
// no vector reduction in sm_89 :(
__device__
__forceinline__
__device__
__forceinline__
static
void
reduce_lora_act
(
float
*
ptr
,
int
rtile
,
lora_act_warp
val
,
bool
pred
)
{
static
void
reduce_lora_act
(
float
*
ptr
,
int
rtile
,
lora_act_warp
val
,
bool
pred
)
{
const
int
laneId
=
threadIdx
.
x
%
WARP_SIZE
;
const
int
laneId
=
threadIdx
.
x
%
WARP_SIZE
;
const
int
warpId
=
threadIdx
.
x
/
WARP_SIZE
;
const
int
warpId
=
threadIdx
.
x
/
WARP_SIZE
;
...
@@ -108,7 +107,6 @@ public:
...
@@ -108,7 +107,6 @@ public:
// });
// });
// }
// }
struct
EpilogueLoraUp
{
struct
EpilogueLoraUp
{
struct
Arguments
{
struct
Arguments
{
const
float
*
lora_act
;
const
float
*
lora_act
;
...
@@ -120,19 +118,23 @@ public:
...
@@ -120,19 +118,23 @@ public:
bool
alwaysfalse
;
bool
alwaysfalse
;
};
};
__device__
__forceinline__
__device__
__forceinline__
static
void
apply_lora_up
(
fpsum_warp
&
fpsum
,
static
void
apply_lora_up
(
fpsum_warp
&
fpsum
,
const
float
*
act
,
const
packed_fpsum_t
*
wgt
,
const
scale_t
&
scales
,
int
rank
,
bool
alwaysfalse
)
{
const
float
*
act
,
const
packed_fpsum_t
*
wgt
,
const
scale_t
&
scales
,
int
rank
,
bool
alwaysfalse
)
{
constexpr
int
NUM_STAGES
=
2
;
constexpr
int
NUM_STAGES
=
2
;
const
int
laneId
=
threadIdx
.
x
%
WARP_SIZE
;
const
int
laneId
=
threadIdx
.
x
%
WARP_SIZE
;
const
int
warpId
=
threadIdx
.
x
/
WARP_SIZE
;
const
int
warpId
=
threadIdx
.
x
/
WARP_SIZE
;
lora_act_warp
lora_act
[
NUM_STAGES
];
// 32
lora_act_warp
lora_act
[
NUM_STAGES
];
// 32
lora_wgt_warp
lora_wgt
[
NUM_STAGES
];
// 64
lora_wgt_warp
lora_wgt
[
NUM_STAGES
];
// 64
int
dummy
=
0
;
int
dummy
=
0
;
#pragma unroll
#pragma unroll
for
(
int
k
=
0
;
k
<
NUM_STAGES
-
1
;
k
++
)
{
for
(
int
k
=
0
;
k
<
NUM_STAGES
-
1
;
k
++
)
{
// we have rank > 0
// we have rank > 0
const
bool
pred
=
k
==
0
?
true
:
k
<
rank
/
WARP_R
;
const
bool
pred
=
k
==
0
?
true
:
k
<
rank
/
WARP_R
;
...
@@ -140,14 +142,14 @@ public:
...
@@ -140,14 +142,14 @@ public:
load_lora_wgt
(
wgt
,
0
,
rank
,
lora_wgt
[
k
],
pred
);
load_lora_wgt
(
wgt
,
0
,
rank
,
lora_wgt
[
k
],
pred
);
}
}
f32psum_warp
f32psum
=
packed_fp16_to_fp32
(
fpsum
);
// 128
f32psum_warp
f32psum
=
packed_fp16_to_fp32
(
fpsum
);
// 128
auto
compute
=
[
&
scales
](
lora_act_warp
A
,
lora_wgt_warp
W
,
f32psum_warp
&
f32psum
,
int
rtile
)
ALWAYSINLINE
{
auto
compute
=
[
&
scales
](
lora_act_warp
A
,
lora_wgt_warp
W
,
f32psum_warp
&
f32psum
,
int
rtile
)
ALWAYSINLINE
{
lora_act16_warp
A_fp16
;
lora_act16_warp
A_fp16
;
for
(
int
m
=
0
;
m
<
LORA_M_TILES
;
m
++
)
{
for
(
int
m
=
0
;
m
<
LORA_M_TILES
;
m
++
)
{
for
(
int
r
=
0
;
r
<
LORA_R_TILES
;
r
++
)
{
for
(
int
r
=
0
;
r
<
LORA_R_TILES
;
r
++
)
{
packed_f32psum_t
pack
=
A
[
m
*
LORA_R_TILES
+
r
];
packed_f32psum_t
pack
=
A
[
m
*
LORA_R_TILES
+
r
];
#pragma unroll
#pragma unroll
for
(
int
j
=
0
;
j
<
8
;
j
++
)
{
for
(
int
j
=
0
;
j
<
8
;
j
++
)
{
pack
.
data
[
j
]
*=
scales
[
rtile
*
LORA_R_TILES
+
r
];
pack
.
data
[
j
]
*=
scales
[
rtile
*
LORA_R_TILES
+
r
];
}
}
...
@@ -159,28 +161,28 @@ public:
...
@@ -159,28 +161,28 @@ public:
for
(
int
r
=
0
;
r
<
LORA_R_TILES
;
r
++
)
{
for
(
int
r
=
0
;
r
<
LORA_R_TILES
;
r
++
)
{
CHECK_NAN
(
lora_act
[
m
*
LORA_R_TILES
+
r
],
"lora_act"
);
CHECK_NAN
(
lora_act
[
m
*
LORA_R_TILES
+
r
],
"lora_act"
);
CHECK_NAN
(
lora_wgt
[
n
*
LORA_R_TILES
+
r
],
"lora_wgt"
);
CHECK_NAN
(
lora_wgt
[
n
*
LORA_R_TILES
+
r
],
"lora_wgt"
);
f32psum
[
m
*
WARP_N_TILES
+
n
]
=
mma_f16xf16_f32
(
A_fp16
[
m
*
LORA_R_TILES
+
r
],
W
[
n
*
LORA_R_TILES
+
r
],
f32psum
[
m
*
WARP_N_TILES
+
n
]);
f32psum
[
m
*
WARP_N_TILES
+
n
]
=
mma_f16xf16_f32
(
A_fp16
[
m
*
LORA_R_TILES
+
r
],
W
[
n
*
LORA_R_TILES
+
r
],
f32psum
[
m
*
WARP_N_TILES
+
n
]);
}
}
}
}
}
}
};
};
for
(
int
k1
=
0
;
k1
<
rank
/
WARP_R
;
k1
+=
NUM_STAGES
)
{
for
(
int
k1
=
0
;
k1
<
rank
/
WARP_R
;
k1
+=
NUM_STAGES
)
{
#pragma unroll
#pragma unroll
for
(
int
k2
=
0
;
k2
<
NUM_STAGES
;
k2
++
)
{
for
(
int
k2
=
0
;
k2
<
NUM_STAGES
;
k2
++
)
{
if
(
k1
+
k2
>=
rank
/
WARP_R
)
{
if
(
k1
+
k2
>=
rank
/
WARP_R
)
{
break
;
break
;
}
}
int
nextk
=
k1
+
k2
+
NUM_STAGES
-
1
;
int
nextk
=
k1
+
k2
+
NUM_STAGES
-
1
;
int
idx
=
(
k2
+
NUM_STAGES
-
1
)
%
NUM_STAGES
;
int
idx
=
(
k2
+
NUM_STAGES
-
1
)
%
NUM_STAGES
;
bool
pred
=
nextk
<
rank
/
WARP_R
;
bool
pred
=
nextk
<
rank
/
WARP_R
;
if
(
alwaysfalse
)
{
if
(
alwaysfalse
)
{
act
+=
kernels
::
bit_cast
<
int
>
(
lora_act
[
k2
][
0
].
data
[
0
]);
act
+=
kernels
::
bit_cast
<
int
>
(
lora_act
[
k2
][
0
].
data
[
0
]);
}
}
if
(
alwaysfalse
)
{
if
(
alwaysfalse
)
{
dummy
=
clock
();
dummy
=
clock
();
}
}
...
@@ -194,25 +196,24 @@ public:
...
@@ -194,25 +196,24 @@ public:
// NVCC does not know rank > 0 :(
// NVCC does not know rank > 0 :(
// it will generate a branch instruction to skip the initial load
// it will generate a branch instruction to skip the initial load
// the branch splits the basic blocks and prevents the overlap of memory access and computing
(packed_fp16_to_fp32)
// the branch splits the basic blocks and prevents the overlap of memory access and computing
// add fake dependency of loaded data so NVCC will not skip the load
//
(packed_fp16_to_fp32)
add fake dependency of loaded data so NVCC will not skip the load
#pragma unroll
#pragma unroll
for
(
int
k
=
0
;
k
<
NUM_STAGES
-
1
;
k
++
)
{
for
(
int
k
=
0
;
k
<
NUM_STAGES
-
1
;
k
++
)
{
#pragma unroll
#pragma unroll
for
(
auto
&&
data
:
lora_act
[
k
])
{
for
(
auto
&&
data
:
lora_act
[
k
])
{
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
dummy
^=
kernels
::
bit_cast
<
int
>
(
data
.
data
[
i
]);
dummy
^=
kernels
::
bit_cast
<
int
>
(
data
.
data
[
i
]);
}
}
}
}
#pragma unroll
#pragma unroll
for
(
auto
&&
data
:
lora_wgt
[
k
])
{
for
(
auto
&&
data
:
lora_wgt
[
k
])
{
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
dummy
^=
kernels
::
bit_cast
<
int
>
(
data
.
data
[
i
]);
dummy
^=
kernels
::
bit_cast
<
int
>
(
data
.
data
[
i
]);
}
}
}
}
}
}
unused_var
(
dummy
,
alwaysfalse
);
unused_var
(
dummy
,
alwaysfalse
);
...
@@ -220,21 +221,20 @@ public:
...
@@ -220,21 +221,20 @@ public:
fpsum
=
packed_fp32_to_fp16
(
f32psum
);
fpsum
=
packed_fp32_to_fp16
(
f32psum
);
}
}
__device__
__forceinline__
__device__
__forceinline__
void
void
operator
()(
const
BlockInfo
binfo
,
fpsum_warp
&
fpsum
,
int
M
,
int
N
,
int
K
,
const
Arguments
&
args
)
{
operator
()(
const
BlockInfo
binfo
,
fpsum_warp
&
fpsum
,
int
M
,
int
N
,
int
K
,
const
Arguments
&
args
)
{
const
int
bm
=
binfo
.
bm
;
const
int
bm
=
binfo
.
bm
;
const
int
bn
=
binfo
.
bn
;
const
int
bn
=
binfo
.
bn
;
CHECK_NAN
(
fpsum
,
"fpsum"
);
CHECK_NAN
(
fpsum
,
"fpsum"
);
apply_lora_up
(
apply_lora_up
(
fpsum
,
fpsum
,
args
.
lora_act
+
args
.
lora_act
+
bm
*
(
args
.
rank
/
WARP_R
)
*
(
NUM_WARPS
*
LORA_M_TILES
*
LORA_R_TILES
*
8
*
WARP_SIZE
),
bm
*
(
args
.
rank
/
WARP_R
)
*
(
NUM_WARPS
*
LORA_M_TILES
*
LORA_R_TILES
*
8
*
WARP_SIZE
),
args
.
lora_wgt_up
+
bn
*
(
BLOCK_N
/
16
)
*
(
args
.
rank
/
16
)
*
WARP_SIZE
,
args
.
lora_wgt_up
+
bn
*
(
BLOCK_N
/
16
)
*
(
args
.
rank
/
16
)
*
WARP_SIZE
,
args
.
scales
,
args
.
scales
,
args
.
rank
,
args
.
rank
,
args
.
alwaysfalse
args
.
alwaysfalse
);
);
CHECK_NAN
(
fpsum
,
"fpsum"
);
CHECK_NAN
(
fpsum
,
"fpsum"
);
}
}
...
@@ -250,16 +250,16 @@ public:
...
@@ -250,16 +250,16 @@ public:
bool
alwaysfalse
;
bool
alwaysfalse
;
};
};
__device__
__forceinline__
__device__
__forceinline__
static
void
static
void
apply_lora_down
(
fpsum_warp
&
fpsum
,
float
*
act
,
const
packed_fpsum_t
*
wgt
,
int
rank
,
bool
alwaysfalse
)
{
apply_lora_down
(
fpsum_warp
&
fpsum
,
float
*
act
,
const
packed_fpsum_t
*
wgt
,
int
rank
,
bool
alwaysfalse
)
{
constexpr
int
NUM_STAGES
=
2
;
constexpr
int
NUM_STAGES
=
2
;
const
int
laneId
=
threadIdx
.
x
%
WARP_SIZE
;
const
int
laneId
=
threadIdx
.
x
%
WARP_SIZE
;
const
int
warpId
=
threadIdx
.
x
/
WARP_SIZE
;
const
int
warpId
=
threadIdx
.
x
/
WARP_SIZE
;
lora_wgt_warp
lora_wgt
[
NUM_STAGES
];
// 64
lora_wgt_warp
lora_wgt
[
NUM_STAGES
];
// 64
#pragma unroll
#pragma unroll
for
(
int
k
=
0
;
k
<
NUM_STAGES
-
1
;
k
++
)
{
for
(
int
k
=
0
;
k
<
NUM_STAGES
-
1
;
k
++
)
{
// we have rank > 0
// we have rank > 0
bool
pred
=
k
==
0
?
true
:
k
<
rank
/
WARP_R
;
bool
pred
=
k
==
0
?
true
:
k
<
rank
/
WARP_R
;
...
@@ -270,11 +270,11 @@ public:
...
@@ -270,11 +270,11 @@ public:
lora_act_warp
lora_act
;
lora_act_warp
lora_act
;
lora_act
.
fill
(
packed_f32psum_t
::
zeros
());
lora_act
.
fill
(
packed_f32psum_t
::
zeros
());
#pragma unroll
#pragma unroll
for
(
int
m
=
0
;
m
<
LORA_M_TILES
;
m
++
)
{
for
(
int
m
=
0
;
m
<
LORA_M_TILES
;
m
++
)
{
#pragma unroll
#pragma unroll
for
(
int
n
=
0
;
n
<
LORA_N_TILES
;
n
++
)
{
for
(
int
n
=
0
;
n
<
LORA_N_TILES
;
n
++
)
{
#pragma unroll
#pragma unroll
for
(
int
r
=
0
;
r
<
LORA_R_TILES
;
r
++
)
{
for
(
int
r
=
0
;
r
<
LORA_R_TILES
;
r
++
)
{
auto
&
psum
=
lora_act
[
m
*
LORA_R_TILES
+
r
];
auto
&
psum
=
lora_act
[
m
*
LORA_R_TILES
+
r
];
...
@@ -294,14 +294,14 @@ public:
...
@@ -294,14 +294,14 @@ public:
int
dummy
=
0
;
int
dummy
=
0
;
for
(
int
k1
=
0
;
k1
<
rank
/
WARP_R
;
k1
+=
NUM_STAGES
)
{
for
(
int
k1
=
0
;
k1
<
rank
/
WARP_R
;
k1
+=
NUM_STAGES
)
{
#pragma unroll
#pragma unroll
for
(
int
k2
=
0
;
k2
<
NUM_STAGES
;
k2
++
)
{
for
(
int
k2
=
0
;
k2
<
NUM_STAGES
;
k2
++
)
{
if
(
k1
+
k2
>=
rank
/
WARP_R
)
{
if
(
k1
+
k2
>=
rank
/
WARP_R
)
{
break
;
break
;
}
}
int
nextk
=
k1
+
k2
+
NUM_STAGES
-
1
;
int
nextk
=
k1
+
k2
+
NUM_STAGES
-
1
;
int
idx
=
(
k2
+
NUM_STAGES
-
1
)
%
NUM_STAGES
;
int
idx
=
(
k2
+
NUM_STAGES
-
1
)
%
NUM_STAGES
;
bool
pred
=
nextk
<
rank
/
WARP_R
;
bool
pred
=
nextk
<
rank
/
WARP_R
;
if
(
alwaysfalse
)
{
if
(
alwaysfalse
)
{
...
@@ -324,38 +324,33 @@ public:
...
@@ -324,38 +324,33 @@ public:
}
}
}
}
#pragma unroll
#pragma unroll
for
(
int
k
=
0
;
k
<
NUM_STAGES
-
1
;
k
++
)
{
for
(
int
k
=
0
;
k
<
NUM_STAGES
-
1
;
k
++
)
{
#pragma unroll
#pragma unroll
for
(
auto
&&
data
:
lora_wgt
[
k
])
{
for
(
auto
&&
data
:
lora_wgt
[
k
])
{
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
dummy
^=
kernels
::
bit_cast
<
int
>
(
data
.
data
[
i
]);
dummy
^=
kernels
::
bit_cast
<
int
>
(
data
.
data
[
i
]);
}
}
}
}
}
}
unused_var
(
dummy
,
alwaysfalse
);
unused_var
(
dummy
,
alwaysfalse
);
}
}
__device__
__forceinline__
__device__
__forceinline__
void
void
operator
()(
const
BlockInfo
binfo
,
fpsum_warp
&
fpsum
,
int
M
,
int
N
,
int
K
,
const
Arguments
&
args
)
{
operator
()(
const
BlockInfo
binfo
,
fpsum_warp
&
fpsum
,
int
M
,
int
N
,
int
K
,
const
Arguments
&
args
)
{
const
int
bm
=
binfo
.
bm
;
const
int
bm
=
binfo
.
bm
;
const
int
bn
=
binfo
.
bn
;
const
int
bn
=
binfo
.
bn
;
apply_lora_down
(
apply_lora_down
(
fpsum
,
fpsum
,
args
.
lora_act
+
args
.
lora_act
+
bm
*
(
args
.
rank
/
WARP_R
)
*
(
NUM_WARPS
*
LORA_M_TILES
*
LORA_R_TILES
*
8
*
WARP_SIZE
),
bm
*
(
args
.
rank
/
WARP_R
)
*
(
NUM_WARPS
*
LORA_M_TILES
*
LORA_R_TILES
*
8
*
WARP_SIZE
),
args
.
lora_wgt_down
+
bn
*
(
BLOCK_N
/
16
)
*
(
args
.
rank
/
16
)
*
WARP_SIZE
,
args
.
lora_wgt_down
+
bn
*
(
BLOCK_N
/
16
)
*
(
args
.
rank
/
16
)
*
WARP_SIZE
,
args
.
rank
,
args
.
rank
,
args
.
alwaysfalse
args
.
alwaysfalse
);
);
}
}
};
};
};
};
};
// namespace nunchaku::kernels
};
// namespace nunchaku::kernels
\ No newline at end of file
src/kernels/zgemm/mma.cuh
View file @
37a27712
...
@@ -7,183 +7,169 @@
...
@@ -7,183 +7,169 @@
namespace
nunchaku
::
kernels
{
namespace
nunchaku
::
kernels
{
namespace
mma_helper
{
namespace
mma_helper
{
struct
f32
{
struct
f32
{
static
constexpr
const
char
value
[]
=
"f32"
;
static
constexpr
const
char
value
[]
=
"f32"
;
};
};
struct
f16
{
struct
f16
{
static
constexpr
const
char
value
[]
=
"f16"
;
static
constexpr
const
char
value
[]
=
"f16"
;
};
};
struct
bf16
{
struct
bf16
{
static
constexpr
const
char
value
[]
=
"bf16"
;
static
constexpr
const
char
value
[]
=
"bf16"
;
};
};
struct
s32
{
struct
s32
{
static
constexpr
const
char
value
[]
=
"s32"
;
static
constexpr
const
char
value
[]
=
"s32"
;
};
struct
s4
{
static
constexpr
const
char
value
[]
=
"s4"
;
};
struct
u4
{
static
constexpr
const
char
value
[]
=
"u4"
;
};
template
<
bool
is_bf16
>
using
f16bf16
=
std
::
conditional_t
<
is_bf16
,
bf16
,
f16
>
;
template
<
bool
is_unsigned
>
using
s4u4
=
std
::
conditional_t
<
is_unsigned
,
u4
,
s4
>
;
};
};
struct
s4
{
static
constexpr
const
char
value
[]
=
"s4"
;
};
struct
u4
{
static
constexpr
const
char
value
[]
=
"u4"
;
};
template
<
bool
is_bf16
>
using
f16bf16
=
std
::
conditional_t
<
is_bf16
,
bf16
,
f16
>
;
template
<
bool
is_unsigned
>
using
s4u4
=
std
::
conditional_t
<
is_unsigned
,
u4
,
s4
>
;
};
// namespace mma_helper
__device__
__forceinline__
__device__
__forceinline__
static
uint2
mma_m16n8k16_f16f16f16f16
(
uint4
a
,
uint2
b
,
uint2
c
)
{
static
uint2
mma_m16n8k16_f16f16f16f16
(
uint4
a
,
uint2
b
,
uint2
c
)
{
uint2
d
;
uint2
d
;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm
volatile
(
asm
volatile
(
"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
"{%0, %1},"
"{%0, %1},"
"{%2, %3, %4, %5},"
"{%2, %3, %4, %5},"
"{%6, %7},"
"{%6, %7},"
"{%8, %9};
\n
"
"{%8, %9};
\n
"
:
"=r"
(
d
.
x
),
"=r"
(
d
.
y
)
:
:
"r"
(
a
.
x
),
"r"
(
a
.
y
),
"r"
(
a
.
z
),
"r"
(
a
.
w
),
"r"
(
b
.
x
),
"r"
(
b
.
y
),
"r"
(
c
.
x
),
"r"
(
c
.
y
));
"=r"
(
d
.
x
),
"=r"
(
d
.
y
)
:
"r"
(
a
.
x
),
"r"
(
a
.
y
),
"r"
(
a
.
z
),
"r"
(
a
.
w
),
"r"
(
b
.
x
),
"r"
(
b
.
y
),
"r"
(
c
.
x
),
"r"
(
c
.
y
)
);
#else
#else
asm
volatile
(
asm
volatile
(
"{"
"{"
".reg .b32 tmp0, tmp1;"
".reg .b32 tmp0, tmp1;"
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
"{tmp0, tmp1},"
"{tmp0, tmp1},"
"{%2, %3},"
"{%2, %3},"
"{%6},"
"{%6},"
"{%8, %9};
\n
"
"{%8, %9};
\n
"
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
"{%0, %1},"
"{%0, %1},"
"{%4, %5},"
"{%4, %5},"
"{%7},"
"{%7},"
"{tmp0, tmp1};"
"{tmp0, tmp1};"
"}
\n
"
"}
\n
"
:
"=r"
(
d
.
x
),
"=r"
(
d
.
y
)
:
:
"r"
(
a
.
x
),
"r"
(
a
.
y
),
"r"
(
a
.
z
),
"r"
(
a
.
w
),
"r"
(
b
.
x
),
"r"
(
b
.
y
),
"r"
(
c
.
x
),
"r"
(
c
.
y
));
"=r"
(
d
.
x
),
"=r"
(
d
.
y
)
:
"r"
(
a
.
x
),
"r"
(
a
.
y
),
"r"
(
a
.
z
),
"r"
(
a
.
w
),
"r"
(
b
.
x
),
"r"
(
b
.
y
),
"r"
(
c
.
x
),
"r"
(
c
.
y
)
);
#endif
#endif
return
d
;
return
d
;
}
}
template
<
bool
is_bf16
>
template
<
bool
is_bf16
>
__device__
__forceinline__
__device__
__forceinline__
static
uint4
mma_m16n8k16_f32f16f16f32
(
uint4
a
,
uint2
b
,
uint4
c
)
{
static
uint4
mma_m16n8k16_f32f16f16f32
(
uint4
a
,
uint2
b
,
uint4
c
)
{
uint4
d
;
uint4
d
;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm
volatile
(
asm
volatile
(
"mma.sync.aligned.m16n8k16.row.col.f32.%14.%14.f32 "
"mma.sync.aligned.m16n8k16.row.col.f32.%14.%14.f32 "
"{%0, %1, %2, %3},"
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%8, %9},"
"{%10, %11, %12, %13};
\n
"
"{%10, %11, %12, %13};
\n
"
:
"=r"
(
d
.
x
),
"=r"
(
d
.
y
),
"=r"
(
d
.
z
),
"=r"
(
d
.
w
)
:
:
"r"
(
a
.
x
),
"=r"
(
d
.
x
),
"=r"
(
d
.
y
),
"=r"
(
d
.
z
),
"=r"
(
d
.
w
)
"r"
(
a
.
y
),
:
"r"
(
a
.
z
),
"r"
(
a
.
x
),
"r"
(
a
.
y
),
"r"
(
a
.
z
),
"r"
(
a
.
w
),
"r"
(
a
.
w
),
"r"
(
b
.
x
),
"r"
(
b
.
y
),
"r"
(
b
.
x
),
"r"
(
c
.
x
),
"r"
(
c
.
y
),
"r"
(
c
.
z
),
"r"
(
c
.
w
),
"r"
(
b
.
y
),
"C"
(
mma_helper
::
f16bf16
<
is_bf16
>::
value
)
"r"
(
c
.
x
),
);
"r"
(
c
.
y
),
"r"
(
c
.
z
),
"r"
(
c
.
w
),
"C"
(
mma_helper
::
f16bf16
<
is_bf16
>::
value
));
#else
#else
static_assert
(
!
is_bf16
);
static_assert
(
!
is_bf16
);
asm
volatile
(
asm
volatile
(
"{"
"{"
".reg .b32 tmp0, tmp1, tmp2, tmp3;"
".reg .b32 tmp0, tmp1, tmp2, tmp3;"
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
"{tmp0, tmp1, tmp2, tmp3},"
"{tmp0, tmp1, tmp2, tmp3},"
"{%4, %5},"
"{%4, %5},"
"{%8},"
"{%8},"
"{%10, %11, %12, %13};
\n
"
"{%10, %11, %12, %13};
\n
"
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
"{%0, %1, %2, %3},"
"{%0, %1, %2, %3},"
"{%6, %7},"
"{%6, %7},"
"{%9},"
"{%9},"
"{tmp0, tmp1, tmp2, tmp3};"
"{tmp0, tmp1, tmp2, tmp3};"
"}
\n
"
"}
\n
"
:
"=r"
(
d
.
x
),
"=r"
(
d
.
y
),
"=r"
(
d
.
z
),
"=r"
(
d
.
w
)
:
:
"r"
(
a
.
x
),
"r"
(
a
.
y
),
"r"
(
a
.
z
),
"r"
(
a
.
w
),
"r"
(
b
.
x
),
"r"
(
b
.
y
),
"r"
(
c
.
x
),
"r"
(
c
.
y
),
"r"
(
c
.
z
),
"r"
(
c
.
w
));
"=r"
(
d
.
x
),
"=r"
(
d
.
y
),
"=r"
(
d
.
z
),
"=r"
(
d
.
w
)
:
"r"
(
a
.
x
),
"r"
(
a
.
y
),
"r"
(
a
.
z
),
"r"
(
a
.
w
),
"r"
(
b
.
x
),
"r"
(
b
.
y
),
"r"
(
c
.
x
),
"r"
(
c
.
y
),
"r"
(
c
.
z
),
"r"
(
c
.
w
)
);
#endif
#endif
return
d
;
return
d
;
}
}
template
<
typename
AType
,
typename
BType
>
template
<
typename
AType
,
typename
BType
>
__device__
__forceinline__
__device__
__forceinline__
static
uint4
mma_m16n8kx_s32common
(
uint4
a
,
uint2
b
,
uint4
c
)
{
static
uint4
mma_m16n8kx_s32common
(
uint4
a
,
uint2
b
,
uint4
c
)
{
uint4
d
;
uint4
d
;
static
constexpr
int
K
=
(
std
::
is_same_v
<
AType
,
mma_helper
::
s4
>
||
std
::
is_same_v
<
AType
,
mma_helper
::
u4
>
)
?
64
:
32
;
static
constexpr
int
K
=
(
std
::
is_same_v
<
AType
,
mma_helper
::
s4
>
||
std
::
is_same_v
<
AType
,
mma_helper
::
u4
>
)
?
64
:
32
;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm
volatile
(
asm
volatile
(
"mma.sync.aligned.m16n8k%14.row.col.s32.%15.%16.s32 "
"mma.sync.aligned.m16n8k%14.row.col.s32.%15.%16.s32 "
"{%0, %1, %2, %3},"
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%8, %9},"
"{%10, %11, %12, %13};
\n
"
"{%10, %11, %12, %13};
\n
"
:
"=r"
(
d
.
x
),
"=r"
(
d
.
y
),
"=r"
(
d
.
z
),
"=r"
(
d
.
w
)
:
:
"r"
(
a
.
x
),
"=r"
(
d
.
x
),
"=r"
(
d
.
y
),
"=r"
(
d
.
z
),
"=r"
(
d
.
w
)
"r"
(
a
.
y
),
:
"r"
(
a
.
z
),
"r"
(
a
.
x
),
"r"
(
a
.
y
),
"r"
(
a
.
z
),
"r"
(
a
.
w
),
"r"
(
a
.
w
),
"r"
(
b
.
x
),
"r"
(
b
.
y
),
"r"
(
b
.
x
),
"r"
(
c
.
x
),
"r"
(
c
.
y
),
"r"
(
c
.
z
),
"r"
(
c
.
w
),
"r"
(
b
.
y
),
"n"
(
K
),
"r"
(
c
.
x
),
"C"
(
AType
::
value
),
"r"
(
c
.
y
),
"C"
(
BType
::
value
)
"r"
(
c
.
z
),
);
"r"
(
c
.
w
),
"n"
(
K
),
"C"
(
AType
::
value
),
"C"
(
BType
::
value
));
#else
#else
asm
volatile
(
asm
volatile
(
"{"
"{"
".reg .b32 tmp0, tmp1, tmp2, tmp3;"
".reg .b32 tmp0, tmp1, tmp2, tmp3;"
"mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
"mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
"{tmp0, tmp1},"
"{tmp0, tmp1},"
"{%4},"
"{%4},"
"{%8},"
"{%8},"
"{%10, %11};
\n
"
"{%10, %11};
\n
"
"mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
"mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
"{tmp2, tmp3},"
"{tmp2, tmp3},"
"{%5},"
"{%5},"
"{%8},"
"{%8},"
"{%12, %13};
\n
"
"{%12, %13};
\n
"
"mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
"mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
"{%0, %1},"
"{%0, %1},"
"{%6},"
"{%6},"
"{%9},"
"{%9},"
"{tmp0, tmp1};
\n
"
"{tmp0, tmp1};
\n
"
"mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
"mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
"{%2, %3},"
"{%2, %3},"
"{%7},"
"{%7},"
"{%9},"
"{%9},"
"{tmp2, tmp3};
\n
"
"{tmp2, tmp3};
\n
"
"}
\n
"
"}
\n
"
:
"=r"
(
d
.
x
),
"=r"
(
d
.
y
),
"=r"
(
d
.
z
),
"=r"
(
d
.
w
)
:
:
"r"
(
a
.
x
),
"=r"
(
d
.
x
),
"=r"
(
d
.
y
),
"=r"
(
d
.
z
),
"=r"
(
d
.
w
)
"r"
(
a
.
y
),
:
"r"
(
a
.
z
),
"r"
(
a
.
x
),
"r"
(
a
.
y
),
"r"
(
a
.
z
),
"r"
(
a
.
w
),
"r"
(
a
.
w
),
"r"
(
b
.
x
),
"r"
(
b
.
y
),
"r"
(
b
.
x
),
"r"
(
c
.
x
),
"r"
(
c
.
y
),
"r"
(
c
.
z
),
"r"
(
c
.
w
),
"r"
(
b
.
y
),
"n"
(
K
/
2
),
"r"
(
c
.
x
),
"C"
(
AType
::
value
),
"r"
(
c
.
y
),
"C"
(
BType
::
value
)
"r"
(
c
.
z
),
);
"r"
(
c
.
w
),
"n"
(
K
/
2
),
"C"
(
AType
::
value
),
"C"
(
BType
::
value
));
#endif
#endif
return
d
;
return
d
;
}
}
};
// namespace nunchaku::kernels
};
// namespace nunchaku::kernels
\ No newline at end of file
src/kernels/zgemm/mma_earlycuda.cuh
View file @
37a27712
...
@@ -6,156 +6,118 @@
...
@@ -6,156 +6,118 @@
// cuda 12.4- does not support "C" constraint in inline assembly :(
// cuda 12.4- does not support "C" constraint in inline assembly :(
// use explicit specialization for now
// use explicit specialization for now
namespace
nunchaku
::
kernels
{
namespace
nunchaku
::
kernels
{
namespace
mma_helper
{
namespace
mma_helper
{
struct
f32
{
struct
f32
{
static
constexpr
const
char
value
[]
=
"f32"
;
static
constexpr
const
char
value
[]
=
"f32"
;
};
};
struct
f16
{
struct
f16
{
static
constexpr
const
char
value
[]
=
"f16"
;
static
constexpr
const
char
value
[]
=
"f16"
;
};
};
struct
bf16
{
struct
bf16
{
static
constexpr
const
char
value
[]
=
"bf16"
;
static
constexpr
const
char
value
[]
=
"bf16"
;
};
};
struct
s32
{
struct
s32
{
static
constexpr
const
char
value
[]
=
"s32"
;
static
constexpr
const
char
value
[]
=
"s32"
;
};
};
struct
s4
{
struct
s4
{
static
constexpr
const
char
value
[]
=
"s4"
;
static
constexpr
const
char
value
[]
=
"s4"
;
};
struct
u4
{
static
constexpr
const
char
value
[]
=
"u4"
;
};
template
<
bool
is_bf16
>
using
f16bf16
=
std
::
conditional_t
<
is_bf16
,
bf16
,
f16
>
;
template
<
bool
is_unsigned
>
using
s4u4
=
std
::
conditional_t
<
is_unsigned
,
u4
,
s4
>
;
};
};
struct
u4
{
static
constexpr
const
char
value
[]
=
"u4"
;
};
template
<
bool
is_bf16
>
using
f16bf16
=
std
::
conditional_t
<
is_bf16
,
bf16
,
f16
>
;
template
<
bool
is_unsigned
>
using
s4u4
=
std
::
conditional_t
<
is_unsigned
,
u4
,
s4
>
;
};
// namespace mma_helper
__device__
__forceinline__
__device__
__forceinline__
static
uint2
mma_m16n8k16_f16f16f16f16
(
uint4
a
,
uint2
b
,
uint2
c
)
{
static
uint2
mma_m16n8k16_f16f16f16f16
(
uint4
a
,
uint2
b
,
uint2
c
)
{
uint2
d
;
uint2
d
;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm
volatile
(
asm
volatile
(
"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
"{%0, %1},"
"{%0, %1},"
"{%2, %3, %4, %5},"
"{%2, %3, %4, %5},"
"{%6, %7},"
"{%6, %7},"
"{%8, %9};
\n
"
"{%8, %9};
\n
"
:
"=r"
(
d
.
x
),
"=r"
(
d
.
y
)
:
:
"r"
(
a
.
x
),
"r"
(
a
.
y
),
"r"
(
a
.
z
),
"r"
(
a
.
w
),
"r"
(
b
.
x
),
"r"
(
b
.
y
),
"r"
(
c
.
x
),
"r"
(
c
.
y
));
"=r"
(
d
.
x
),
"=r"
(
d
.
y
)
:
"r"
(
a
.
x
),
"r"
(
a
.
y
),
"r"
(
a
.
z
),
"r"
(
a
.
w
),
"r"
(
b
.
x
),
"r"
(
b
.
y
),
"r"
(
c
.
x
),
"r"
(
c
.
y
)
);
#else
#else
asm
volatile
(
asm
volatile
(
"{"
"{"
".reg .b32 tmp0, tmp1;"
".reg .b32 tmp0, tmp1;"
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
"{tmp0, tmp1},"
"{tmp0, tmp1},"
"{%2, %3},"
"{%2, %3},"
"{%6},"
"{%6},"
"{%8, %9};
\n
"
"{%8, %9};
\n
"
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
"{%0, %1},"
"{%0, %1},"
"{%4, %5},"
"{%4, %5},"
"{%7},"
"{%7},"
"{tmp0, tmp1};"
"{tmp0, tmp1};"
"}
\n
"
"}
\n
"
:
"=r"
(
d
.
x
),
"=r"
(
d
.
y
)
:
:
"r"
(
a
.
x
),
"r"
(
a
.
y
),
"r"
(
a
.
z
),
"r"
(
a
.
w
),
"r"
(
b
.
x
),
"r"
(
b
.
y
),
"r"
(
c
.
x
),
"r"
(
c
.
y
));
"=r"
(
d
.
x
),
"=r"
(
d
.
y
)
:
"r"
(
a
.
x
),
"r"
(
a
.
y
),
"r"
(
a
.
z
),
"r"
(
a
.
w
),
"r"
(
b
.
x
),
"r"
(
b
.
y
),
"r"
(
c
.
x
),
"r"
(
c
.
y
)
);
#endif
#endif
return
d
;
return
d
;
}
}
template
<
bool
is_bf16
>
template
<
bool
is_bf16
>
__device__
__forceinline__
__device__
__forceinline__
static
uint4
mma_m16n8k16_f32f16f16f32
(
uint4
a
,
uint2
b
,
uint4
c
)
=
delete
;
static
uint4
mma_m16n8k16_f32f16f16f32
(
uint4
a
,
uint2
b
,
uint4
c
)
=
delete
;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
template
<
>
template
<
>
__device__
__forceinline__
__device__
__forceinline__
uint4
mma_m16n8k16_f32f16f16f32
<
true
>
(
uint4
a
,
uint2
b
,
uint4
c
)
{
uint4
mma_m16n8k16_f32f16f16f32
<
true
>
(
uint4
a
,
uint2
b
,
uint4
c
)
{
uint4
d
;
uint4
d
;
asm
volatile
(
asm
volatile
(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0, %1, %2, %3},"
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%8, %9},"
"{%10, %11, %12, %13};
\n
"
"{%10, %11, %12, %13};
\n
"
:
"=r"
(
d
.
x
),
"=r"
(
d
.
y
),
"=r"
(
d
.
z
),
"=r"
(
d
.
w
)
:
:
"r"
(
a
.
x
),
"r"
(
a
.
y
),
"r"
(
a
.
z
),
"r"
(
a
.
w
),
"r"
(
b
.
x
),
"r"
(
b
.
y
),
"r"
(
c
.
x
),
"r"
(
c
.
y
),
"r"
(
c
.
z
),
"r"
(
c
.
w
));
"=r"
(
d
.
x
),
"=r"
(
d
.
y
),
"=r"
(
d
.
z
),
"=r"
(
d
.
w
)
:
"r"
(
a
.
x
),
"r"
(
a
.
y
),
"r"
(
a
.
z
),
"r"
(
a
.
w
),
"r"
(
b
.
x
),
"r"
(
b
.
y
),
"r"
(
c
.
x
),
"r"
(
c
.
y
),
"r"
(
c
.
z
),
"r"
(
c
.
w
)
);
return
d
;
return
d
;
}
}
#endif
#endif
template
<
>
template
<
>
__device__
__forceinline__
__device__
__forceinline__
uint4
mma_m16n8k16_f32f16f16f32
<
false
>
(
uint4
a
,
uint2
b
,
uint4
c
)
{
uint4
mma_m16n8k16_f32f16f16f32
<
false
>
(
uint4
a
,
uint2
b
,
uint4
c
)
{
uint4
d
;
uint4
d
;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm
volatile
(
asm
volatile
(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0, %1, %2, %3},"
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%8, %9},"
"{%10, %11, %12, %13};
\n
"
"{%10, %11, %12, %13};
\n
"
:
"=r"
(
d
.
x
),
"=r"
(
d
.
y
),
"=r"
(
d
.
z
),
"=r"
(
d
.
w
)
:
:
"r"
(
a
.
x
),
"r"
(
a
.
y
),
"r"
(
a
.
z
),
"r"
(
a
.
w
),
"r"
(
b
.
x
),
"r"
(
b
.
y
),
"r"
(
c
.
x
),
"r"
(
c
.
y
),
"r"
(
c
.
z
),
"r"
(
c
.
w
));
"=r"
(
d
.
x
),
"=r"
(
d
.
y
),
"=r"
(
d
.
z
),
"=r"
(
d
.
w
)
:
"r"
(
a
.
x
),
"r"
(
a
.
y
),
"r"
(
a
.
z
),
"r"
(
a
.
w
),
"r"
(
b
.
x
),
"r"
(
b
.
y
),
"r"
(
c
.
x
),
"r"
(
c
.
y
),
"r"
(
c
.
z
),
"r"
(
c
.
w
)
);
#else
#else
asm
volatile
(
asm
volatile
(
"{"
"{"
".reg .b32 tmp0, tmp1, tmp2, tmp3;"
".reg .b32 tmp0, tmp1, tmp2, tmp3;"
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
"{tmp0, tmp1, tmp2, tmp3},"
"{tmp0, tmp1, tmp2, tmp3},"
"{%4, %5},"
"{%4, %5},"
"{%8},"
"{%8},"
"{%10, %11, %12, %13};
\n
"
"{%10, %11, %12, %13};
\n
"
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
"{%0, %1, %2, %3},"
"{%0, %1, %2, %3},"
"{%6, %7},"
"{%6, %7},"
"{%9},"
"{%9},"
"{tmp0, tmp1, tmp2, tmp3};"
"{tmp0, tmp1, tmp2, tmp3};"
"}
\n
"
"}
\n
"
:
"=r"
(
d
.
x
),
"=r"
(
d
.
y
),
"=r"
(
d
.
z
),
"=r"
(
d
.
w
)
:
:
"r"
(
a
.
x
),
"r"
(
a
.
y
),
"r"
(
a
.
z
),
"r"
(
a
.
w
),
"r"
(
b
.
x
),
"r"
(
b
.
y
),
"r"
(
c
.
x
),
"r"
(
c
.
y
),
"r"
(
c
.
z
),
"r"
(
c
.
w
));
"=r"
(
d
.
x
),
"=r"
(
d
.
y
),
"=r"
(
d
.
z
),
"=r"
(
d
.
w
)
:
"r"
(
a
.
x
),
"r"
(
a
.
y
),
"r"
(
a
.
z
),
"r"
(
a
.
w
),
"r"
(
b
.
x
),
"r"
(
b
.
y
),
"r"
(
c
.
x
),
"r"
(
c
.
y
),
"r"
(
c
.
z
),
"r"
(
c
.
w
)
);
#endif
#endif
return
d
;
return
d
;
}
}
template
<
typename
AType
,
typename
BType
>
template
<
typename
AType
,
typename
BType
>
__device__
__forceinline__
__device__
__forceinline__
static
uint4
mma_m16n8kx_s32common
(
uint4
a
,
uint2
b
,
uint4
c
)
=
delete
;
static
uint4
mma_m16n8kx_s32common
(
uint4
a
,
uint2
b
,
uint4
c
)
=
delete
;
template
<
>
template
<
>
__device__
__forceinline__
__device__
__forceinline__
uint4
mma_m16n8kx_s32common
<
mma_helper
::
s4
,
mma_helper
::
s4
>
(
uint4
a
,
uint2
b
,
uint4
c
)
{
uint4
mma_m16n8kx_s32common
<
mma_helper
::
s4
,
mma_helper
::
s4
>
(
uint4
a
,
uint2
b
,
uint4
c
)
{
uint4
d
;
uint4
d
;
static
constexpr
int
K
=
64
;
static
constexpr
int
K
=
64
;
...
@@ -166,54 +128,50 @@ uint4 mma_m16n8kx_s32common<mma_helper::s4, mma_helper::s4>(uint4 a, uint2 b, ui
...
@@ -166,54 +128,50 @@ uint4 mma_m16n8kx_s32common<mma_helper::s4, mma_helper::s4>(uint4 a, uint2 b, ui
"{%4, %5, %6, %7},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%8, %9},"
"{%10, %11, %12, %13};
\n
"
"{%10, %11, %12, %13};
\n
"
:
:
"=r"
(
d
.
x
),
"=r"
(
d
.
y
),
"=r"
(
d
.
z
),
"=r"
(
d
.
w
)
"=r"
(
d
.
x
),
"=r"
(
d
.
y
),
"=r"
(
d
.
z
),
"=r"
(
d
.
w
)
:
"r"
(
a
.
x
),
"r"
(
a
.
y
),
"r"
(
a
.
z
),
"r"
(
a
.
w
),
"r"
(
b
.
x
),
"r"
(
b
.
y
),
"r"
(
c
.
x
),
"r"
(
c
.
y
),
"r"
(
c
.
z
),
"r"
(
c
.
w
),
"n"
(
K
));
:
"r"
(
a
.
x
),
"r"
(
a
.
y
),
"r"
(
a
.
z
),
"r"
(
a
.
w
),
"r"
(
b
.
x
),
"r"
(
b
.
y
),
"r"
(
c
.
x
),
"r"
(
c
.
y
),
"r"
(
c
.
z
),
"r"
(
c
.
w
),
"n"
(
K
)
);
#else
#else
asm
volatile
(
asm
volatile
(
"{"
"{"
".reg .b32 tmp0, tmp1, tmp2, tmp3;"
".reg .b32 tmp0, tmp1, tmp2, tmp3;"
"mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
"mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
"{tmp0, tmp1},"
"{tmp0, tmp1},"
"{%4},"
"{%4},"
"{%8},"
"{%8},"
"{%10, %11};
\n
"
"{%10, %11};
\n
"
"mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
"mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
"{tmp2, tmp3},"
"{tmp2, tmp3},"
"{%5},"
"{%5},"
"{%8},"
"{%8},"
"{%12, %13};
\n
"
"{%12, %13};
\n
"
"mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
"mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
"{%0, %1},"
"{%0, %1},"
"{%6},"
"{%6},"
"{%9},"
"{%9},"
"{tmp0, tmp1};
\n
"
"{tmp0, tmp1};
\n
"
"mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
"mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
"{%2, %3},"
"{%2, %3},"
"{%7},"
"{%7},"
"{%9},"
"{%9},"
"{tmp2, tmp3};
\n
"
"{tmp2, tmp3};
\n
"
"}
\n
"
"}
\n
"
:
"=r"
(
d
.
x
),
"=r"
(
d
.
y
),
"=r"
(
d
.
z
),
"=r"
(
d
.
w
)
:
:
"r"
(
a
.
x
),
"=r"
(
d
.
x
),
"=r"
(
d
.
y
),
"=r"
(
d
.
z
),
"=r"
(
d
.
w
)
"r"
(
a
.
y
),
:
"r"
(
a
.
z
),
"r"
(
a
.
x
),
"r"
(
a
.
y
),
"r"
(
a
.
z
),
"r"
(
a
.
w
),
"r"
(
a
.
w
),
"r"
(
b
.
x
),
"r"
(
b
.
y
),
"r"
(
b
.
x
),
"r"
(
c
.
x
),
"r"
(
c
.
y
),
"r"
(
c
.
z
),
"r"
(
c
.
w
),
"r"
(
b
.
y
),
"n"
(
K
/
2
)
"r"
(
c
.
x
),
);
"r"
(
c
.
y
),
"r"
(
c
.
z
),
"r"
(
c
.
w
),
"n"
(
K
/
2
));
#endif
#endif
return
d
;
return
d
;
}
}
template
<
>
template
<
>
__device__
__forceinline__
__device__
__forceinline__
uint4
mma_m16n8kx_s32common
<
mma_helper
::
u4
,
mma_helper
::
s4
>
(
uint4
a
,
uint2
b
,
uint4
c
)
{
uint4
mma_m16n8kx_s32common
<
mma_helper
::
u4
,
mma_helper
::
s4
>
(
uint4
a
,
uint2
b
,
uint4
c
)
{
uint4
d
;
uint4
d
;
static
constexpr
int
K
=
64
;
static
constexpr
int
K
=
64
;
...
@@ -224,50 +182,46 @@ uint4 mma_m16n8kx_s32common<mma_helper::u4, mma_helper::s4>(uint4 a, uint2 b, ui
...
@@ -224,50 +182,46 @@ uint4 mma_m16n8kx_s32common<mma_helper::u4, mma_helper::s4>(uint4 a, uint2 b, ui
"{%4, %5, %6, %7},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%8, %9},"
"{%10, %11, %12, %13};
\n
"
"{%10, %11, %12, %13};
\n
"
:
:
"=r"
(
d
.
x
),
"=r"
(
d
.
y
),
"=r"
(
d
.
z
),
"=r"
(
d
.
w
)
"=r"
(
d
.
x
),
"=r"
(
d
.
y
),
"=r"
(
d
.
z
),
"=r"
(
d
.
w
)
:
"r"
(
a
.
x
),
"r"
(
a
.
y
),
"r"
(
a
.
z
),
"r"
(
a
.
w
),
"r"
(
b
.
x
),
"r"
(
b
.
y
),
"r"
(
c
.
x
),
"r"
(
c
.
y
),
"r"
(
c
.
z
),
"r"
(
c
.
w
),
"n"
(
K
));
:
"r"
(
a
.
x
),
"r"
(
a
.
y
),
"r"
(
a
.
z
),
"r"
(
a
.
w
),
"r"
(
b
.
x
),
"r"
(
b
.
y
),
"r"
(
c
.
x
),
"r"
(
c
.
y
),
"r"
(
c
.
z
),
"r"
(
c
.
w
),
"n"
(
K
)
);
#else
#else
asm
volatile
(
asm
volatile
(
"{"
"{"
".reg .b32 tmp0, tmp1, tmp2, tmp3;"
".reg .b32 tmp0, tmp1, tmp2, tmp3;"
"mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
"mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
"{tmp0, tmp1},"
"{tmp0, tmp1},"
"{%4},"
"{%4},"
"{%8},"
"{%8},"
"{%10, %11};
\n
"
"{%10, %11};
\n
"
"mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
"mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
"{tmp2, tmp3},"
"{tmp2, tmp3},"
"{%5},"
"{%5},"
"{%8},"
"{%8},"
"{%12, %13};
\n
"
"{%12, %13};
\n
"
"mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
"mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
"{%0, %1},"
"{%0, %1},"
"{%6},"
"{%6},"
"{%9},"
"{%9},"
"{tmp0, tmp1};
\n
"
"{tmp0, tmp1};
\n
"
"mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
"mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
"{%2, %3},"
"{%2, %3},"
"{%7},"
"{%7},"
"{%9},"
"{%9},"
"{tmp2, tmp3};
\n
"
"{tmp2, tmp3};
\n
"
"}
\n
"
"}
\n
"
:
"=r"
(
d
.
x
),
"=r"
(
d
.
y
),
"=r"
(
d
.
z
),
"=r"
(
d
.
w
)
:
:
"r"
(
a
.
x
),
"=r"
(
d
.
x
),
"=r"
(
d
.
y
),
"=r"
(
d
.
z
),
"=r"
(
d
.
w
)
"r"
(
a
.
y
),
:
"r"
(
a
.
z
),
"r"
(
a
.
x
),
"r"
(
a
.
y
),
"r"
(
a
.
z
),
"r"
(
a
.
w
),
"r"
(
a
.
w
),
"r"
(
b
.
x
),
"r"
(
b
.
y
),
"r"
(
b
.
x
),
"r"
(
c
.
x
),
"r"
(
c
.
y
),
"r"
(
c
.
z
),
"r"
(
c
.
w
),
"r"
(
b
.
y
),
"n"
(
K
/
2
)
"r"
(
c
.
x
),
);
"r"
(
c
.
y
),
"r"
(
c
.
z
),
"r"
(
c
.
w
),
"n"
(
K
/
2
));
#endif
#endif
return
d
;
return
d
;
}
}
};
// namespace nunchaku::kernels
};
// namespace nunchaku::kernels
\ No newline at end of file
src/kernels/zgemm/zgemm.h
View file @
37a27712
...
@@ -5,50 +5,55 @@
...
@@ -5,50 +5,55 @@
namespace
nunchaku
::
kernels
{
namespace
nunchaku
::
kernels
{
void
gemm_w4a4
(
void
gemm_w4a4
(
Tensor
act
,
// packed act [M, K / 2]
Tensor
act
,
// packed act [M, K / 2]
Tensor
wgt
,
// packed act [N, K / 2]
Tensor
wgt
,
// packed act [N, K / 2]
Tensor
out
,
// linear [M, N]
Tensor
out
,
// linear [M, N]
Tensor
qout
,
// packed act [M, N / 2]
Tensor
qout
,
// packed act [M, N / 2]
Tensor
ascales
,
// packed as [K / 64, M]
Tensor
ascales
,
// packed as [K / 64, M]
Tensor
wscales
,
// packed ws [K / 64, N]
Tensor
wscales
,
// packed ws [K / 64, N]
Tensor
oscales
,
// packed as [N / 64, M]
Tensor
oscales
,
// packed as [N / 64, M]
Tensor
poolout
,
// linear [M / PoolSize, N]
Tensor
poolout
,
// linear [M / PoolSize, N]
Tensor
lora_act_in
,
// packed lora_act [M, R]
Tensor
lora_act_in
,
// packed lora_act [M, R]
Tensor
lora_up
,
// packed lora_wgt [N, R]
Tensor
lora_up
,
// packed lora_wgt [N, R]
Tensor
lora_down
,
// packed lora_wgt [N, R]
Tensor
lora_down
,
// packed lora_wgt [N, R]
Tensor
lora_act_out
,
// packed lora_act [M, R]
Tensor
lora_act_out
,
// packed lora_act [M, R]
Tensor
norm_q
,
// linear [HEAD_DIM]
Tensor
norm_q
,
// linear [HEAD_DIM]
Tensor
norm_k
,
// linear [HEAD_DIM]
Tensor
norm_k
,
// linear [HEAD_DIM]
Tensor
rotary_emb
,
// linear [M, HEAD_DIM / 2, 2, 2]
Tensor
rotary_emb
,
// linear [M, HEAD_DIM / 2, 2, 2]
Tensor
bias
,
// packed ws [N]
Tensor
bias
,
// packed ws [N]
Tensor
smooth_factor
,
// packed ws [N], for quantization of the next layer
Tensor
smooth_factor
,
// packed ws [N], for quantization of the next layer
Tensor
out_vk
,
// linear [B, num_heads, head_dim + 1, head_dim]
Tensor
out_vk
,
// linear [B, num_heads, head_dim + 1, head_dim]
Tensor
out_linearattn
,
// linear [B, (M), N / 3]
Tensor
out_linearattn
,
// linear [B, (M), N / 3]
bool
act_unsigned
,
bool
act_unsigned
,
std
::
vector
<
float
>
lora_scales
,
// [R / 16]
std
::
vector
<
float
>
lora_scales
,
// [R / 16]
bool
fuse_silu
,
bool
fuse_silu
,
bool
fp4
,
bool
fp4
,
float
alpha
,
float
alpha
,
Tensor
wcscales
,
Tensor
wcscales
,
Tensor
out_q
,
// packed attention [B, H, M, D]
Tensor
out_q
,
// packed attention [B, H, M, D]
Tensor
out_k
,
// packed attention [B, H, M, D]
Tensor
out_k
,
// packed attention [B, H, M, D]
Tensor
out_v
,
// packed attention [B, H, M, D]
Tensor
out_v
,
// packed attention [B, H, M, D]
int
attn_tokens
);
int
attn_tokens
);
void
linearattn_vk_mul_q
(
Tensor
q
,
Tensor
vk
);
void
linearattn_vk_mul_q
(
Tensor
q
,
Tensor
vk
);
void
quantize_w4a4_act_fuse_lora
(
Tensor
input
,
Tensor
output
,
Tensor
oscales
,
Tensor
lora_down
,
Tensor
lora_act_out
,
Tensor
smooth
=
{},
bool
fuse_glu
=
false
,
bool
fp4
=
false
);
void
quantize_w4a4_act_fuse_lora
(
Tensor
input
,
Tensor
output
,
Tensor
oscales
,
Tensor
lora_down
,
Tensor
lora_act_out
,
Tensor
smooth
=
{},
bool
fuse_glu
=
false
,
bool
fp4
=
false
);
void
quantize_w4a4_act
(
Tensor
input
,
Tensor
output
,
Tensor
oscales
);
void
quantize_w4a4_act
(
Tensor
input
,
Tensor
output
,
Tensor
oscales
);
void
quantize_w4a4_wgt
(
Tensor
input
,
Tensor
output
,
Tensor
oscales
);
void
quantize_w4a4_wgt
(
Tensor
input
,
Tensor
output
,
Tensor
oscales
);
void
gemm_w8a8
(
Tensor
act
,
// [M, K]
void
gemm_w8a8
(
Tensor
act
,
// [M, K]
Tensor
wgt
,
// [N, K]
Tensor
wgt
,
// [N, K]
Tensor
out
,
// [M, N]
Tensor
out
,
// [M, N]
Tensor
ascales
,
// [1, M]
Tensor
ascales
,
// [1, M]
Tensor
wscales
,
// [1, N]
Tensor
wscales
,
// [1, N]
Tensor
bias
// packed ws [N]
Tensor
bias
// packed ws [N]
);
);
void
quantize_w8a8_act
(
Tensor
input
,
Tensor
output
,
Tensor
oscales
,
bool
fuse_glu
);
void
quantize_w8a8_act
(
Tensor
input
,
Tensor
output
,
Tensor
oscales
,
bool
fuse_glu
);
...
@@ -61,13 +66,11 @@ void quantize_w8a8_act(Tensor input, Tensor output, Tensor oscales, bool fuse_gl
...
@@ -61,13 +66,11 @@ void quantize_w8a8_act(Tensor input, Tensor output, Tensor oscales, bool fuse_gl
// Tensor wscales // [1, N]
// Tensor wscales // [1, N]
// );
// );
void
attention_fp16
(
void
attention_fp16
(
Tensor
q
,
// packed [Batch, Head, TokensQ, HEAD_DIM]
Tensor
q
,
// packed [Batch, Head, TokensQ, HEAD_DIM]
Tensor
k
,
// packed [Batch, Head, TokensKV, HEAD_DIM]
Tensor
k
,
// packed [Batch, Head, TokensKV, HEAD_DIM]
Tensor
v
,
// packed [Batch, Head, TokensKV, HEAD_DIM]
Tensor
v
,
// packed [Batch, Head, TokensKV, HEAD_DIM]
Tensor
o
,
// linear [Batch, TokensQ, Head * HEAD_DIM]
Tensor
o
,
// linear [Batch, TokensQ, Head * HEAD_DIM]
float
scale
);
float
scale
);
// EXPERIMENTAL, for sm_75
// EXPERIMENTAL, for sm_75
void
set_faster_i2f_mode
(
std
::
string
mode
);
void
set_faster_i2f_mode
(
std
::
string
mode
);
...
@@ -76,4 +79,4 @@ void set_faster_i2f_mode(std::string mode);
...
@@ -76,4 +79,4 @@ void set_faster_i2f_mode(std::string mode);
void
test_rmsnorm_rope
(
Tensor
input
,
Tensor
output
,
Tensor
norm_q
,
Tensor
norm_k
,
Tensor
rotary_emb
);
void
test_rmsnorm_rope
(
Tensor
input
,
Tensor
output
,
Tensor
norm_q
,
Tensor
norm_k
,
Tensor
rotary_emb
);
void
test_pack_qkv
(
Tensor
input
,
Tensor
out_q
,
Tensor
out_k
,
Tensor
out_v
,
int
numTokens
);
void
test_pack_qkv
(
Tensor
input
,
Tensor
out_q
,
Tensor
out_k
,
Tensor
out_v
,
int
numTokens
);
};
// namespace nunchaku::kernels
};
// namespace nunchaku::kernels
\ No newline at end of file
src/layernorm.cpp
View file @
37a27712
#include "layernorm.h"
#include "layernorm.h"
#include "kernels/layernorm_kernels.h"
#include "kernels/layernorm_kernels.h"
LayerNorm
::
LayerNorm
(
int
hidden_size
,
float
eps
,
bool
elementwise_affine
,
Tensor
::
ScalarType
dtype
,
Device
device
)
:
LayerNorm
::
LayerNorm
(
int
hidden_size
,
float
eps
,
bool
elementwise_affine
,
Tensor
::
ScalarType
dtype
,
Device
device
)
hidden_size
(
hidden_size
),
eps
(
eps
)
:
hidden_size
(
hidden_size
),
eps
(
eps
)
{
{
if
(
elementwise_affine
)
{
if
(
elementwise_affine
)
{
weight
=
Tensor
::
allocate
({
hidden_size
},
dtype
,
device
);
weight
=
Tensor
::
allocate
({
hidden_size
},
dtype
,
device
);
bias
=
Tensor
::
allocate
({
hidden_size
},
dtype
,
device
);
bias
=
Tensor
::
allocate
({
hidden_size
},
dtype
,
device
);
}
}
registerParams
registerParams
(
weight
,
"weight"
)(
bias
,
"bias"
);
(
weight
,
"weight"
)
(
bias
,
"bias"
)
;
}
}
Tensor
LayerNorm
::
forward
(
Tensor
x
)
{
Tensor
LayerNorm
::
forward
(
Tensor
x
)
{
...
@@ -27,10 +23,23 @@ Tensor RMSNorm::forward(Tensor x) {
...
@@ -27,10 +23,23 @@ Tensor RMSNorm::forward(Tensor x) {
return
out
;
return
out
;
}
}
void
RMSNormGeneral
::
forward_with_act_sum
(
Tensor
x
,
Tensor
quantized_hidden_states_buffer
,
Tensor
quantized_scale_buffer
,
Tensor
quantized_sum_buffer
)
{
void
RMSNormGeneral
::
forward_with_act_sum
(
Tensor
x
,
rms_norm_general_fuse_sum
(
quantized_hidden_states_buffer
,
x
,
this
->
weight
,
quantized_sum_buffer
,
quantized_scale_buffer
,
variance_epsilon
,
use_per_token_quant
);
Tensor
quantized_hidden_states_buffer
,
Tensor
quantized_scale_buffer
,
Tensor
quantized_sum_buffer
)
{
rms_norm_general_fuse_sum
(
quantized_hidden_states_buffer
,
x
,
this
->
weight
,
quantized_sum_buffer
,
quantized_scale_buffer
,
variance_epsilon
,
use_per_token_quant
);
}
}
void
RMSNormGeneral
::
forward_wo_act_sum
(
Tensor
x
,
Tensor
quantized_hidden_states_buffer
,
Tensor
quantized_scale_buffer
,
Tensor
quantized_sum_buffer
)
{
void
RMSNormGeneral
::
forward_wo_act_sum
(
Tensor
x
,
rms_norm_general
(
quantized_hidden_states_buffer
,
x
,
this
->
weight
,
quantized_scale_buffer
,
variance_epsilon
,
use_per_token_quant
);
Tensor
quantized_hidden_states_buffer
,
Tensor
quantized_scale_buffer
,
Tensor
quantized_sum_buffer
)
{
rms_norm_general
(
quantized_hidden_states_buffer
,
x
,
this
->
weight
,
quantized_scale_buffer
,
variance_epsilon
,
use_per_token_quant
);
}
}
src/layernorm.h
View file @
37a27712
...
@@ -20,9 +20,8 @@ private:
...
@@ -20,9 +20,8 @@ private:
class
RMSNorm
:
public
Module
{
class
RMSNorm
:
public
Module
{
public:
public:
RMSNorm
(
int
hidden_size
,
float
eps
,
bool
use_quant
,
Tensor
::
ScalarType
dtype
,
Device
device
)
:
RMSNorm
(
int
hidden_size
,
float
eps
,
bool
use_quant
,
Tensor
::
ScalarType
dtype
,
Device
device
)
use_quant
(
use_quant
),
variance_epsilon
(
eps
)
:
use_quant
(
use_quant
),
variance_epsilon
(
eps
)
{
{
weight
=
Tensor
::
allocate
({
hidden_size
},
dtype
,
device
);
weight
=
Tensor
::
allocate
({
hidden_size
},
dtype
,
device
);
registerParams
(
weight
,
"weight"
);
registerParams
(
weight
,
"weight"
);
}
}
...
@@ -36,13 +35,16 @@ public:
...
@@ -36,13 +35,16 @@ public:
class
RMSNormGeneral
{
class
RMSNormGeneral
{
friend
class
LlamaDecoderLayer
;
friend
class
LlamaDecoderLayer
;
public:
public:
RMSNormGeneral
(
int
hidden_size
,
bool
act_sum
,
float
eps
,
bool
use_per_token_quant
,
Device
device
)
RMSNormGeneral
(
int
hidden_size
,
bool
act_sum
,
float
eps
,
bool
use_per_token_quant
,
Device
device
)
:
act_sum
(
act_sum
),
use_per_token_quant
(
use_per_token_quant
),
variance_epsilon
(
eps
)
:
act_sum
(
act_sum
),
use_per_token_quant
(
use_per_token_quant
),
variance_epsilon
(
eps
)
{
{
this
->
weight
=
Tensor
::
ones
({
hidden_size
},
Tensor
::
FP32
,
device
);
this
->
weight
=
Tensor
::
ones
({
hidden_size
},
Tensor
::
FP32
,
device
);
}
}
void
forward
(
Tensor
x
,
Tensor
quantized_hidden_states_buffer
,
Tensor
quantized_scale_buffer
,
Tensor
quantized_sum_buffer
)
{
void
forward
(
Tensor
x
,
Tensor
quantized_hidden_states_buffer
,
Tensor
quantized_scale_buffer
,
Tensor
quantized_sum_buffer
)
{
if
(
act_sum
)
{
if
(
act_sum
)
{
forward_with_act_sum
(
x
,
quantized_hidden_states_buffer
,
quantized_scale_buffer
,
quantized_sum_buffer
);
forward_with_act_sum
(
x
,
quantized_hidden_states_buffer
,
quantized_scale_buffer
,
quantized_sum_buffer
);
}
else
{
}
else
{
...
@@ -51,12 +53,18 @@ public:
...
@@ -51,12 +53,18 @@ public:
}
}
private:
private:
void
forward_with_act_sum
(
Tensor
x
,
Tensor
quantized_hidden_states_buffer
,
Tensor
quantized_scale_buffer
,
Tensor
quantized_sum_buffer
);
void
forward_with_act_sum
(
Tensor
x
,
void
forward_wo_act_sum
(
Tensor
x
,
Tensor
quantized_hidden_states_buffer
,
Tensor
quantized_scale_buffer
,
Tensor
quantized_sum_buffer
);
Tensor
quantized_hidden_states_buffer
,
Tensor
quantized_scale_buffer
,
Tensor
quantized_sum_buffer
);
void
forward_wo_act_sum
(
Tensor
x
,
Tensor
quantized_hidden_states_buffer
,
Tensor
quantized_scale_buffer
,
Tensor
quantized_sum_buffer
);
private:
private:
const
bool
act_sum
;
const
bool
act_sum
;
const
bool
use_per_token_quant
;
const
bool
use_per_token_quant
;
const
float
variance_epsilon
;
const
float
variance_epsilon
;
Tensor
weight
;
Tensor
weight
;
};
};
\ No newline at end of file
src/pytorch_compat.h
View file @
37a27712
...
@@ -4,103 +4,106 @@
...
@@ -4,103 +4,106 @@
#include "Tensor.h"
#include "Tensor.h"
namespace
pytorch_compat
{
namespace
pytorch_compat
{
inline
void
TORCH_CHECK
(
bool
cond
,
const
std
::
string
&
msg
=
""
)
{
inline
void
TORCH_CHECK
(
bool
cond
,
const
std
::
string
&
msg
=
""
)
{
assert
(
cond
);
assert
(
cond
);
}
}
template
<
typename
T
>
template
<
typename
T
>
inline
void
C10_CUDA_CHECK
(
T
ret
)
{
inline
void
C10_CUDA_CHECK
(
T
ret
)
{
return
checkCUDA
(
ret
);
return
checkCUDA
(
ret
);
}
}
namespace
at
{
namespace
at
{
using
::
Tensor
;
using
::
Tensor
;
constexpr
auto
kFloat32
=
Tensor
::
FP32
;
constexpr
auto
kFloat
=
Tensor
::
FP32
;
constexpr
auto
kFloat16
=
Tensor
::
FP16
;
constexpr
auto
kBFloat16
=
Tensor
::
BF16
;
constexpr
auto
kInt32
=
Tensor
::
INT32
;
constexpr
auto
kInt64
=
Tensor
::
INT64
;
struct
Generator
{
Generator
()
{
throw
std
::
runtime_error
(
"Not implemented"
);
}
std
::
mutex
mutex_
;
};
namespace
cuda
{
using
::
getCurrentDeviceProperties
;
struct
StreamWrapper
{
cudaStream_t
st
;
cudaStream_t
stream
()
const
{
return
st
;
}
};
inline
StreamWrapper
getCurrentCUDAStream
()
{
return
StreamWrapper
(
::
getCurrentCUDAStream
());
}
struct
CUDAGuard
{
int
dev
;
};
namespace
detail
{
inline
Generator
getDefaultCUDAGenerator
()
{
return
Generator
();
}
}
}
using
CUDAGeneratorImpl
=
Generator
;
template
<
typename
T
>
std
::
unique_ptr
<
Generator
>
get_generator_or_default
(
std
::
optional
<
Generator
>
gen
,
T
gen2
)
{
throw
std
::
runtime_error
(
"Not implemented"
);
}
}
namespace
torch
{
constexpr
auto
kFloat32
=
Tensor
::
FP32
;
using
at
::
kFloat32
;
constexpr
auto
kFloat
=
Tensor
::
FP32
;
using
at
::
kFloat
;
constexpr
auto
kFloat16
=
Tensor
::
FP16
;
using
at
::
kFloat16
;
constexpr
auto
kBFloat16
=
Tensor
::
BF16
;
using
at
::
kBFloat16
;
constexpr
auto
kInt32
=
Tensor
::
INT32
;
using
at
::
kInt32
;
constexpr
auto
kInt64
=
Tensor
::
INT64
;
using
at
::
kInt64
;
constexpr
Device
kCUDA
=
Device
::
cuda
();
struct
Generator
{
Generator
()
{
using
IntArrayRef
=
std
::
vector
<
int
>
;
throw
std
::
runtime_error
(
"Not implemented"
);
using
TensorOptions
=
Tensor
::
TensorOptions
;
inline
Tensor
empty_like
(
const
Tensor
&
tensor
)
{
return
Tensor
::
empty_like
(
tensor
);
}
inline
Tensor
empty
(
TensorShape
shape
,
Tensor
::
TensorOptions
options
)
{
return
Tensor
::
empty
(
shape
,
options
.
dtype
(),
options
.
device
());
}
inline
Tensor
zeros
(
TensorShape
shape
,
Tensor
::
TensorOptions
options
)
{
return
Tensor
::
empty
(
shape
,
options
.
dtype
(),
options
.
device
()).
zero_
();
}
namespace
nn
{
namespace
functional
{
using
PadFuncOptions
=
std
::
vector
<
int
>
;
inline
Tensor
pad
(
Tensor
x
,
PadFuncOptions
options
)
{
throw
std
::
runtime_error
(
"Not implemented"
);
}
}
}
namespace
indexing
{
constexpr
int
None
=
0
;
struct
Slice
{
int
a
;
int
b
;
};
}
}
}
std
::
mutex
mutex_
;
};
namespace
cuda
{
using
::
getCurrentDeviceProperties
;
namespace
c10
{
struct
StreamWrapper
{
using
std
::
optional
;
cudaStream_t
st
;
cudaStream_t
stream
()
const
{
return
st
;
}
}
};
inline
StreamWrapper
getCurrentCUDAStream
()
{
return
StreamWrapper
(
::
getCurrentCUDAStream
());
}
struct
CUDAGuard
{
int
dev
;
};
namespace
detail
{
inline
Generator
getDefaultCUDAGenerator
()
{
return
Generator
();
}
}
// namespace detail
}
// namespace cuda
using
CUDAGeneratorImpl
=
Generator
;
template
<
typename
T
>
std
::
unique_ptr
<
Generator
>
get_generator_or_default
(
std
::
optional
<
Generator
>
gen
,
T
gen2
)
{
throw
std
::
runtime_error
(
"Not implemented"
);
}
}
// namespace at
namespace
torch
{
using
at
::
kFloat32
;
using
at
::
kFloat
;
using
at
::
kFloat16
;
using
at
::
kBFloat16
;
using
at
::
kInt32
;
using
at
::
kInt64
;
constexpr
Device
kCUDA
=
Device
::
cuda
();
using
IntArrayRef
=
std
::
vector
<
int
>
;
using
TensorOptions
=
Tensor
::
TensorOptions
;
inline
Tensor
empty_like
(
const
Tensor
&
tensor
)
{
return
Tensor
::
empty_like
(
tensor
);
}
inline
Tensor
empty
(
TensorShape
shape
,
Tensor
::
TensorOptions
options
)
{
return
Tensor
::
empty
(
shape
,
options
.
dtype
(),
options
.
device
());
}
inline
Tensor
zeros
(
TensorShape
shape
,
Tensor
::
TensorOptions
options
)
{
return
Tensor
::
empty
(
shape
,
options
.
dtype
(),
options
.
device
()).
zero_
();
}
namespace
nn
{
namespace
functional
{
using
PadFuncOptions
=
std
::
vector
<
int
>
;
inline
Tensor
pad
(
Tensor
x
,
PadFuncOptions
options
)
{
throw
std
::
runtime_error
(
"Not implemented"
);
}
}
// namespace functional
}
// namespace nn
namespace
indexing
{
constexpr
int
None
=
0
;
struct
Slice
{
int
a
;
int
b
;
};
}
// namespace indexing
}
// namespace torch
namespace
c10
{
using
std
::
optional
;
}
}
}
// namespace pytorch_compat
tests/README.md
View file @
37a27712
...
@@ -35,7 +35,7 @@ To test visual output correctness, you can:
...
@@ -35,7 +35,7 @@ To test visual output correctness, you can:
lpips
=
compute_lpips
(
dir1, dir2
)
lpips
=
compute_lpips
(
dir1, dir2
)
```
```
Here,
`dir1`
should point to the directory containing the reference images, and
`dir2`
should contain the images generated by your method.
Here,
`dir1`
should point to the directory containing the reference images, and
`dir2`
should contain the images generated by your method.
### Setting the LPIPS Threshold
### Setting the LPIPS Threshold
...
@@ -43,4 +43,4 @@ To pass the test, the LPIPS score must be below a predefined threshold—typical
...
@@ -43,4 +43,4 @@ To pass the test, the LPIPS score must be below a predefined threshold—typical
## Acknowledgments
## Acknowledgments
This contribution guide is adapted from
[
SGLang
](
https://github.com/sgl-project/sglang/tree/main/test
)
. We thank them for the inspiration.
This contribution guide is adapted from
[
SGLang
](
https://github.com/sgl-project/sglang/tree/main/test
)
. We thank them for the inspiration.
\ No newline at end of file
tests/data/__init__.py
View file @
37a27712
...
@@ -3,7 +3,6 @@ import random
...
@@ -3,7 +3,6 @@ import random
import
datasets
import
datasets
import
yaml
import
yaml
from
huggingface_hub
import
snapshot_download
from
nunchaku.utils
import
fetch_or_download
from
nunchaku.utils
import
fetch_or_download
...
...
tests/flux/test_flux_cache.py
View file @
37a27712
import
pytest
import
pytest
from
nunchaku.utils
import
get_precision
,
is_turing
from
nunchaku.utils
import
get_precision
,
is_turing
from
.utils
import
run_test
from
.utils
import
run_test
...
@@ -8,7 +9,7 @@ from .utils import run_test
...
@@ -8,7 +9,7 @@ from .utils import run_test
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"cache_threshold,height,width,num_inference_steps,lora_name,lora_strength,expected_lpips"
,
"cache_threshold,height,width,num_inference_steps,lora_name,lora_strength,expected_lpips"
,
[
[
(
0.12
,
1024
,
1024
,
30
,
None
,
1
,
0.212
if
get_precision
()
==
"int4"
else
0.1
44
),
(
0.12
,
1024
,
1024
,
30
,
None
,
1
,
0.212
if
get_precision
()
==
"int4"
else
0.1
61
),
],
],
)
)
def
test_flux_dev_cache
(
def
test_flux_dev_cache
(
...
...
tests/flux/test_flux_dev.py
View file @
37a27712
import
pytest
import
pytest
from
nunchaku.utils
import
get_precision
,
is_turing
from
nunchaku.utils
import
get_precision
,
is_turing
from
.utils
import
run_test
from
.utils
import
run_test
...
@@ -9,7 +10,7 @@ from .utils import run_test
...
@@ -9,7 +10,7 @@ from .utils import run_test
"height,width,num_inference_steps,attention_impl,cpu_offload,expected_lpips"
,
"height,width,num_inference_steps,attention_impl,cpu_offload,expected_lpips"
,
[
[
(
1024
,
1024
,
50
,
"flashattn2"
,
False
,
0.139
if
get_precision
()
==
"int4"
else
0.146
),
(
1024
,
1024
,
50
,
"flashattn2"
,
False
,
0.139
if
get_precision
()
==
"int4"
else
0.146
),
(
2048
,
512
,
25
,
"nunchaku-fp16"
,
False
,
0.168
if
get_precision
()
==
"int4"
else
0.1
33
),
(
2048
,
512
,
25
,
"nunchaku-fp16"
,
False
,
0.168
if
get_precision
()
==
"int4"
else
0.1
56
),
],
],
)
)
def
test_flux_dev
(
def
test_flux_dev
(
...
...
Prev
1
…
5
6
7
8
9
10
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment