fengzch-das / nunchaku · Commits · e3e2a92a

Commit e3e2a92a, authored Nov 10, 2024 by Zhekai Zhang
parent db223c25

[major] fix running on Windows

Showing 4 changed files with 58 additions and 18 deletions (+58 -18)
setup.py                    +11  -10
src/FluxModel.cpp            +1   -1
src/kernels/flash_attn       +0   -1
src/kernels/gemm_w4a4.cu    +46   -6
setup.py

@@ -29,6 +29,7 @@ if __name__ == "__main__":
         "third_party/json/include",
         "third_party/mio/include",
         "third_party/spdlog/include",
+        "third_party/Block-Sparse-Attention/csrc/block_sparse_attn",
     ]
     INCLUDE_DIRS = [ROOT_DIR + "/" + dir for dir in INCLUDE_DIRS]

@@ -89,14 +90,14 @@ if __name__ == "__main__":
             "src/Linear.cpp",
             *ncond("src/FluxModel.cpp"),
             "src/Serialization.cpp",
-            *ncond("src/kernels/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu"),
-            *ncond("src/kernels/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu"),
-            *ncond("src/kernels/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu"),
-            *ncond("src/kernels/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu"),
-            *ncond("src/kernels/flash_attn/src/flash_fwd_block_hdim64_fp16_sm80.cu"),
-            *ncond("src/kernels/flash_attn/src/flash_fwd_block_hdim64_bf16_sm80.cu"),
-            *ncond("src/kernels/flash_attn/src/flash_fwd_block_hdim128_fp16_sm80.cu"),
-            *ncond("src/kernels/flash_attn/src/flash_fwd_block_hdim128_bf16_sm80.cu"),
+            *ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_hdim64_fp16_sm80.cu"),
+            *ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_hdim64_bf16_sm80.cu"),
+            *ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_hdim128_fp16_sm80.cu"),
+            *ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_hdim128_bf16_sm80.cu"),
+            *ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_block_hdim64_fp16_sm80.cu"),
+            *ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_block_hdim64_bf16_sm80.cu"),
+            *ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_block_hdim128_fp16_sm80.cu"),
+            *ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_block_hdim128_bf16_sm80.cu"),
             "src/kernels/activation_kernels.cu",
             "src/kernels/layernorm_kernels.cu",
             "src/kernels/misc_kernels.cu",

@@ -104,8 +105,8 @@ if __name__ == "__main__":
             "src/kernels/gemm_batched.cu",
             "src/kernels/gemm_f16.cu",
             "src/kernels/awq/gemv_awq.cu",
-            *ncond("src/kernels/flash_attn/flash_api.cpp"),
-            *ncond("src/kernels/flash_attn/flash_api_adapter.cpp"),
+            *ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/flash_api.cpp"),
+            *ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/flash_api_adapter.cpp"),
         ],
         extra_compile_args={"gcc": GCC_FLAGS, "msvc": MSVC_FLAGS, "nvcc": NVCC_FLAGS, "nvcc_msvc": NVCC_MSVC_FLAGS},
         include_dirs=INCLUDE_DIRS,
src/FluxModel.cpp

 #include "FluxModel.h"
 #include "kernels/misc_kernels.h"
-#include "kernels/flash_attn/flash_api.h"
 #include "kernels/gemm_batched.h"
+#include "flash_api.h"
 #include "activation.h"
 #include <nvtx3/nvToolsExt.h>
src/kernels/flash_attn (symlink deleted, file mode 120000 → 0)

-../../third_party/Block-Sparse-Attention/csrc/block_sparse_attn
\ No newline at end of file
src/kernels/gemm_w4a4.cu

@@ -1631,6 +1631,10 @@ public:
     void apply_bias(fpsum_warp &fpsum, half_t *out, int M, int N, int K, const packed_wscale_t *bias) {
         const int laneId = threadIdx.x % WARP_SIZE;

+        // if (laneId == 0) {
+        //     printf("block.x=%d block.y=%d warpId=%d bias=%p\n", blockIdx.x, blockIdx.y, threadIdx.x / WARP_SIZE, bias);
+        // }
+
         wscale_warp b;
         load_wscale(bias, 0, N, b, true);
@@ -1884,6 +1888,8 @@ public:
         bool swapBlockXY, bool alwaysfalse)
     {
+        // printf("Device sizeof(args) = %d", (int)sizeof(epilogueArgs));
+
         BlockInfo binfo = {
             .bm = (int)blockIdx.x,
             .bn = (int)blockIdx.y,
@@ -2654,6 +2660,24 @@ static void invoke_kernel(T ...args) {
     kernel()(args...);
 }

+template<typename T>
+__global__ static void test_sizeof_device() {
+    printf("sizeof on device = %d\n", (int)sizeof(T));
+}
+
+template<typename T>
+static void test_sizeof_host() {
+    printf("sizeof on host = %d\n", (int)sizeof(T));
+}
+
+template<typename T>
+static void test_sizeof() {
+    printf("typeid = %s\n", typeid(T).name());
+    test_sizeof_host<T>();
+    test_sizeof_device<T><<<1, 1>>>();
+    checkCUDA(cudaDeviceSynchronize());
+}
+
 void gemm_w4a4(
     Tensor act,  // packed act [M, K / 2]
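Editor's note: the three helpers added above are debugging utilities. test_sizeof_host and test_sizeof_device print sizeof(T) as seen by the host compiler and by the device compilation pass, and test_sizeof runs both (checkCUDA is the project's existing CUDA error-check wrapper). A minimal usage sketch follows; the call site and the checked types are illustrative assumptions and are not part of this commit, though the nested-tuple case mirrors the comment added further down in this file.

// Hypothetical call site (not in the commit): compare host vs. device layout of
// types that are passed by value into kernels. Divergent printed sizes indicate
// the host/device ABI mismatch this commit works around.
#include <tuple>

static void debug_check_arg_layouts() {
    test_sizeof<std::tuple<std::tuple<int>>>();  // the case noted in this commit: reported as 8 bytes on device
    test_sizeof<std::tuple<int, float *>>();     // any other kernel-argument type can be checked the same way
}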
@@ -2683,6 +2707,13 @@ void gemm_w4a4(
     int K = act.shape[-1] * 2;
     assert(K == wgt.shape[1] * 2);

+    // spdlog::info("M={} N={} K={}", M, N, K);
+    // spdlog::info("act at {}", act.data_ptr());
+    // spdlog::info("wgt at {}", wgt.data_ptr());
+    // spdlog::info("ascales at {}", ascales.data_ptr());
+    // spdlog::info("wscales at {}", wscales.data_ptr());
+    // spdlog::info("bias at {}", bias.data_ptr());
+
     auto launch = [&]<typename Epilogue>(Epilogue::Arguments args) {
         dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);
@@ -2692,6 +2723,10 @@ void gemm_w4a4(
         }

         dispatchBool(act_unsigned, [&]<bool ACT_UNSIGNED>() {
+            // test_sizeof<Epilogue::Arguments>();
+            // std::apply([](auto ...args) {
+            //     (test_sizeof<decltype(args)>(), ...);
+            // }, args);
             invoke_kernel<GEMM::gemm_w4a4_kernel<Epilogue, ACT_UNSIGNED>><<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS>>>(
                 act.data_ptr<GEMM::packed_act_t>(),
                 wgt.data_ptr<GEMM::packed_wgt_t>(),
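Editor's note: dispatchBool here lifts the runtime flag act_unsigned into the compile-time template parameter ACT_UNSIGNED before invoke_kernel instantiates gemm_w4a4_kernel. Its definition is not part of this diff; the sketch below is an assumed minimal implementation of such a helper, inferred only from the call shape above, not the project's actual code.

// Assumed sketch: branch once on the runtime bool on the host and forward the
// matching compile-time constant to the C++20 templated lambda, so each kernel
// instantiation is specialized and no per-thread branch is paid on the device.
#include <utility>

template <typename F>
void dispatchBool(bool value, F &&func) {
    if (value) {
        std::forward<F>(func).template operator()<true>();
    } else {
        std::forward<F>(func).template operator()<false>();
    }
}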
@@ -2715,12 +2750,15 @@ void gemm_w4a4(
         assert(bias.numel() == N);

-        using Epilogue = GEMM::EpilogueCombination<GEMM::EpilogueBias, NextEpilogue>;
+        // append EpilgoueNop to workaround mismatched memory layout of std::tuple between device and host code on Windows
+        // ** sizeof(std::tuple<std::tuple<int>>) == 8 on device **
+        using Epilogue = GEMM::EpilogueCombination<GEMM::EpilogueBias, NextEpilogue, GEMM::EpilogueNop>;
         return launch.template operator()<Epilogue>({
             GEMM::EpilogueBias::Arguments{
                 .bias = bias.data_ptr<GEMM::packed_wscale_t>(),
             },
-            nextArgs
+            nextArgs,
+            {}
         });
     };
     // auto launch_bias = launch;
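Editor's note: the comment in this hunk describes the core of the Windows fix: the host compiler and the device compilation pass can disagree on the layout of std::tuple, so an epilogue argument pack copied by value into the kernel is read with the wrong offsets. Appending GEMM::EpilogueNop to the EpilogueCombination (and a matching {} to the argument list) pads the pack so both sides agree. The standalone sketch below only illustrates how such a host/device sizeof divergence can be observed; it is not project code, and the exact host-side value depends on the compiler.

// Illustrative check (assumption, not from the commit): if host and device
// report different sizes for a kernel-argument type, passing that type by
// value through a <<<...>>> launch is not layout-safe.
#include <cstdio>
#include <tuple>
#include <cuda_runtime.h>

using Args = std::tuple<std::tuple<int>>;  // the type the commit comment singles out

__global__ void print_device_sizeof() {
    printf("sizeof on device = %d\n", (int)sizeof(Args));
}

int main() {
    printf("sizeof on host   = %d\n", (int)sizeof(Args));
    print_device_sizeof<<<1, 1>>>();
    cudaDeviceSynchronize();  // differing numbers mean the two compilers lay the type out differently
    return 0;
}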
@@ -2754,7 +2792,7 @@ void gemm_w4a4(
     }

     if (!lora_down.valid()) {
-        using Epilogue = typename GEMM::EpilogueCombination<typename LoraUp::EpilogueLoraUp, MidEpilogue, NextEpilogue>;
+        using Epilogue = typename GEMM::EpilogueCombination<typename LoraUp::EpilogueLoraUp, MidEpilogue, NextEpilogue, GEMM::EpilogueNop>;
         return launch_bias.template operator()<Epilogue>({
             typename LoraUp::EpilogueLoraUp::Arguments{
                 .lora_act = lora_act_in.data_ptr<float>(),
@@ -2762,7 +2800,8 @@ void gemm_w4a4(
                 .scales = scales,
             },
             midArgs,
-            nextArgs
+            nextArgs,
+            {}
         });
     }
@@ -2780,7 +2819,7 @@ void gemm_w4a4(
     // dispatchVal(rank_down, std::integer_sequence<int, 16, 32, 48, 64, 80>(), [&]<int RANK_DOWN>() {
         using LoraDown = LoraUp; // GEMM::Lora<RANK_DOWN>;
-        using Epilogue = GEMM::EpilogueCombination<typename LoraUp::EpilogueLoraUp, MidEpilogue, typename LoraDown::EpilogueLoraDown, NextEpilogue>;
+        using Epilogue = GEMM::EpilogueCombination<typename LoraUp::EpilogueLoraUp, MidEpilogue, typename LoraDown::EpilogueLoraDown, NextEpilogue, GEMM::EpilogueNop>;
         return launch_bias.template operator()<Epilogue>({
             typename LoraUp::EpilogueLoraUp::Arguments{
                 .lora_act = lora_act_in.data_ptr<float>(),
@@ -2792,7 +2831,8 @@ void gemm_w4a4(
             .lora_wgt_down = lora_down.data_ptr<GEMM::packed_fpsum_t>(),
             .lora_act = lora_act_out.data_ptr<float>(),
         },
-        nextArgs
+        nextArgs,
+        {}
     });
     // });