"vscode:/vscode.git/clone" did not exist on "0a08d41961220887c97074dcd585e52bba9f6220"
Commit e3e2a92a authored by Zhekai Zhang's avatar Zhekai Zhang
Browse files

[major] fix running on Windows

parent db223c25
@@ -29,6 +29,7 @@ if __name__ == "__main__":
"third_party/json/include",
"third_party/mio/include",
"third_party/spdlog/include",
"third_party/Block-Sparse-Attention/csrc/block_sparse_attn",
]
INCLUDE_DIRS = [ROOT_DIR + "/" + dir for dir in INCLUDE_DIRS]
@@ -89,14 +90,14 @@ if __name__ == "__main__":
"src/Linear.cpp",
*ncond("src/FluxModel.cpp"),
"src/Serialization.cpp",
*ncond("src/kernels/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu"),
*ncond("src/kernels/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu"),
*ncond("src/kernels/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu"),
*ncond("src/kernels/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu"),
*ncond("src/kernels/flash_attn/src/flash_fwd_block_hdim64_fp16_sm80.cu"),
*ncond("src/kernels/flash_attn/src/flash_fwd_block_hdim64_bf16_sm80.cu"),
*ncond("src/kernels/flash_attn/src/flash_fwd_block_hdim128_fp16_sm80.cu"),
*ncond("src/kernels/flash_attn/src/flash_fwd_block_hdim128_bf16_sm80.cu"),
*ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_hdim64_fp16_sm80.cu"),
*ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_hdim64_bf16_sm80.cu"),
*ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_hdim128_fp16_sm80.cu"),
*ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_hdim128_bf16_sm80.cu"),
*ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_block_hdim64_fp16_sm80.cu"),
*ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_block_hdim64_bf16_sm80.cu"),
*ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_block_hdim128_fp16_sm80.cu"),
*ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_block_hdim128_bf16_sm80.cu"),
"src/kernels/activation_kernels.cu",
"src/kernels/layernorm_kernels.cu",
"src/kernels/misc_kernels.cu",
@@ -104,8 +105,8 @@ if __name__ == "__main__":
"src/kernels/gemm_batched.cu",
"src/kernels/gemm_f16.cu",
"src/kernels/awq/gemv_awq.cu",
*ncond("src/kernels/flash_attn/flash_api.cpp"),
*ncond("src/kernels/flash_attn/flash_api_adapter.cpp"),
*ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/flash_api.cpp"),
*ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/flash_api_adapter.cpp"),
],
extra_compile_args={"gcc": GCC_FLAGS, "msvc": MSVC_FLAGS, "nvcc": NVCC_FLAGS, "nvcc_msvc": NVCC_MSVC_FLAGS},
include_dirs=INCLUDE_DIRS,
......
#include "FluxModel.h"
#include "kernels/misc_kernels.h"
#include "kernels/flash_attn/flash_api.h"
#include "kernels/gemm_batched.h"
#include "flash_api.h"
#include "activation.h"
#include <nvtx3/nvToolsExt.h>
......
../../third_party/Block-Sparse-Attention/csrc/block_sparse_attn
\ No newline at end of file
@@ -1631,6 +1631,10 @@ public:
void apply_bias(fpsum_warp &fpsum, half_t *out, int M, int N, int K, const packed_wscale_t *bias) {
const int laneId = threadIdx.x % WARP_SIZE;
// if (laneId == 0) {
// printf("block.x=%d block.y=%d warpId=%d bias=%p\n", blockIdx.x, blockIdx.y, threadIdx.x / WARP_SIZE, bias);
// }
wscale_warp b;
load_wscale(bias, 0, N, b, true);
@@ -1884,6 +1888,8 @@ public:
bool swapBlockXY,
bool alwaysfalse)
{
// printf("Device sizeof(args) = %d", (int)sizeof(epilogueArgs));
BlockInfo binfo = {
.bm = (int)blockIdx.x,
.bn = (int)blockIdx.y,
@@ -2654,6 +2660,24 @@ static void invoke_kernel(T ...args) {
kernel()(args...);
}
template<typename T>
__global__
static void test_sizeof_device() {
printf("sizeof on device = %d\n", (int)sizeof(T));
}
template<typename T>
static void test_sizeof_host() {
printf("sizeof on host = %d\n", (int)sizeof(T));
}
template<typename T>
static void test_sizeof() {
printf("typeid = %s\n", typeid(T).name());
test_sizeof_host<T>();
test_sizeof_device<T><<<1, 1>>>();
checkCUDA(cudaDeviceSynchronize());
}
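A usage sketch, not part of the commit: the helpers above are debugging aids that print a type's size as seen by the host compiler and by the device compiler, which is how a Windows-only std::tuple layout mismatch can be confirmed. The probed types below are illustrative, and <tuple> is assumed to be available in this translation unit.

// Hypothetical usage of the helpers above (probed types are illustrative):
static void debug_tuple_layouts() {
    test_sizeof<std::tuple<int>>();              // flat argument pack
    test_sizeof<std::tuple<std::tuple<int>>>();  // nested pack; a comment further down reports 8 bytes on device
}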
void gemm_w4a4(
Tensor act, // packed act [M, K / 2]
@@ -2683,6 +2707,13 @@ void gemm_w4a4(
int K = act.shape[-1] * 2;
assert(K == wgt.shape[1] * 2);
// spdlog::info("M={} N={} K={}", M, N, K);
// spdlog::info("act at {}", act.data_ptr());
// spdlog::info("wgt at {}", wgt.data_ptr());
// spdlog::info("ascales at {}", ascales.data_ptr());
// spdlog::info("wscales at {}", wscales.data_ptr());
// spdlog::info("bias at {}", bias.data_ptr());
auto launch = [&]<typename Epilogue>(Epilogue::Arguments args) {
dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);
@@ -2692,6 +2723,10 @@ void gemm_w4a4(
}
dispatchBool(act_unsigned, [&]<bool ACT_UNSIGNED>() {
// test_sizeof<Epilogue::Arguments>();
// std::apply([](auto ...args) {
// (test_sizeof<decltype(args)>(), ...);
// }, args);
invoke_kernel<GEMM::gemm_w4a4_kernel<Epilogue, ACT_UNSIGNED>><<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS>>>(
act.data_ptr<GEMM::packed_act_t>(),
wgt.data_ptr<GEMM::packed_wgt_t>(),
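For context on the launch above: dispatchBool maps the runtime act_unsigned flag onto the compile-time ACT_UNSIGNED template parameter before the kernel is instantiated. A minimal sketch of that dispatch pattern follows; dispatch_bool_sketch is a hypothetical name and signature, not the repository's actual helper.

#include <utility>  // std::forward (assumed available here)

// Sketch of a runtime-bool to compile-time-bool dispatcher compatible with a
// C++20 generic lambda such as [&]<bool ACT_UNSIGNED>() { ... }.
template<typename F>
static void dispatch_bool_sketch(bool value, F &&func) {
    if (value) {
        std::forward<F>(func).template operator()<true>();
    } else {
        std::forward<F>(func).template operator()<false>();
    }
}

// Usage (hypothetical): the flag becomes a template argument of the kernel,
// so each value gets its own instantiation at compile time.
// dispatch_bool_sketch(act_unsigned, [&]<bool ACT_UNSIGNED>() { /* launch kernel<Epilogue, ACT_UNSIGNED> */ });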
@@ -2715,12 +2750,15 @@ void gemm_w4a4(
assert(bias.numel() == N);
using Epilogue = GEMM::EpilogueCombination<GEMM::EpilogueBias, NextEpilogue>;
// append EpilogueNop to work around the mismatched memory layout of std::tuple between device and host code on Windows
// ** sizeof(std::tuple<std::tuple<int>>) == 8 on device **
using Epilogue = GEMM::EpilogueCombination<GEMM::EpilogueBias, NextEpilogue, GEMM::EpilogueNop>;
return launch.template operator()<Epilogue>({
GEMM::EpilogueBias::Arguments{
.bias = bias.data_ptr<GEMM::packed_wscale_t>(),
},
nextArgs
nextArgs,
{}
});
};
// auto launch_bias = launch;
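The comment in the hunk above is the heart of the Windows fix: the epilogue argument pack is passed to the kernel by value, so the host compiler (MSVC) and the device compiler (nvcc) must agree on the layout of the std::tuple that carries it, and appending GEMM::EpilogueNop (plus an extra {} in each braced initializer) restores that agreement. Below is a reduced, standalone sketch of the pattern; the struct names are illustrative stand-ins, not the repository's types.

#include <tuple>  // the commented std::apply above suggests the combined Arguments is a std::tuple

struct BiasArgs { const void *bias; };  // stands in for GEMM::EpilogueBias::Arguments
struct NextArgs { float scale; };       // stands in for NextEpilogue::Arguments
struct NopArgs  { };                    // stands in for GEMM::EpilogueNop::Arguments

// Without the trailing empty member, host and device were observed to disagree
// on the tuple's size on Windows; with it, both sides line up.
using PaddedPack = std::tuple<BiasArgs, NextArgs, NopArgs>;

// Mirrors the extra "{}" added to each braced argument pack at the call sites above.
static PaddedPack make_padded_pack(const void *bias, NextArgs next) {
    return PaddedPack{ BiasArgs{bias}, next, NopArgs{} };
}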
@@ -2754,7 +2792,7 @@ void gemm_w4a4(
}
if (!lora_down.valid()) {
using Epilogue = typename GEMM::EpilogueCombination<typename LoraUp::EpilogueLoraUp, MidEpilogue, NextEpilogue>;
using Epilogue = typename GEMM::EpilogueCombination<typename LoraUp::EpilogueLoraUp, MidEpilogue, NextEpilogue, GEMM::EpilogueNop>;
return launch_bias.template operator()<Epilogue>({
typename LoraUp::EpilogueLoraUp::Arguments{
.lora_act = lora_act_in.data_ptr<float>(),
@@ -2762,7 +2800,8 @@ void gemm_w4a4(
.scales = scales,
},
midArgs,
nextArgs
nextArgs,
{}
});
}
@@ -2780,7 +2819,7 @@ void gemm_w4a4(
// dispatchVal(rank_down, std::integer_sequence<int, 16, 32, 48, 64, 80>(), [&]<int RANK_DOWN>() {
using LoraDown = LoraUp; // GEMM::Lora<RANK_DOWN>;
using Epilogue = GEMM::EpilogueCombination<typename LoraUp::EpilogueLoraUp, MidEpilogue, typename LoraDown::EpilogueLoraDown, NextEpilogue>;
using Epilogue = GEMM::EpilogueCombination<typename LoraUp::EpilogueLoraUp, MidEpilogue, typename LoraDown::EpilogueLoraDown, NextEpilogue, GEMM::EpilogueNop>;
return launch_bias.template operator()<Epilogue>({
typename LoraUp::EpilogueLoraUp::Arguments{
.lora_act = lora_act_in.data_ptr<float>(),
@@ -2792,7 +2831,8 @@ void gemm_w4a4(
.lora_wgt_down = lora_down.data_ptr<GEMM::packed_fpsum_t>(),
.lora_act = lora_act_out.data_ptr<float>(),
},
nextArgs
nextArgs,
{}
});
// });
......