Commit 03ac3784 authored by fengzch's avatar fengzch
Browse files

fix: compile gemv_awq.cu complete

parent fcbab540
......@@ -127,6 +127,7 @@ if __name__ == "__main__":
GCC_FLAGS = ["-DENABLE_BF16=1", "-DBUILD_NUNCHAKU=1", "-fvisibility=hidden", "-g", "-std=c++2a", "-UNDEBUG", "-Og"]
MSVC_FLAGS = ["/DENABLE_BF16=1", "/DBUILD_NUNCHAKU=1", "/std:c++2a", "/UNDEBUG", "/Zc:__cplusplus", "/FS"]
NVCC_FLAGS = [
"-DDCU_ASM",
"-DENABLE_BF16=1",
"-DBUILD_NUNCHAKU=1",
"-g",
......@@ -172,34 +173,34 @@ if __name__ == "__main__":
*ncond("src/SanaModel.cpp"),
"src/Serialization.cpp",
"src/Module.cpp",
*ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_hdim64_fp16_sm80.cu"),
*ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_hdim64_bf16_sm80.cu"),
*ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_hdim128_fp16_sm80.cu"),
*ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_hdim128_bf16_sm80.cu"),
*ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_block_hdim64_fp16_sm80.cu"),
*ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_block_hdim64_bf16_sm80.cu"),
*ncond(
"third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_block_hdim128_fp16_sm80.cu"
),
*ncond(
"third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_block_hdim128_bf16_sm80.cu"
),
"src/kernels/activation_kernels.cu",
"src/kernels/layernorm_kernels.cu",
"src/kernels/misc_kernels.cu",
"src/kernels/zgemm/gemm_w4a4.cu",
"src/kernels/zgemm/gemm_w4a4_test.cu",
"src/kernels/zgemm/gemm_w4a4_launch_fp16_int4.cu",
"src/kernels/zgemm/gemm_w4a4_launch_fp16_int4_fasteri2f.cu",
"src/kernels/zgemm/gemm_w4a4_launch_fp16_fp4.cu",
"src/kernels/zgemm/gemm_w4a4_launch_bf16_int4.cu",
"src/kernels/zgemm/gemm_w4a4_launch_bf16_fp4.cu",
"src/kernels/zgemm/gemm_w8a8.cu",
"src/kernels/zgemm/attention.cu",
"src/kernels/dwconv.cu",
"src/kernels/gemm_batched.cu",
"src/kernels/gemm_f16.cu",
"src/kernels/awq/gemm_awq.cu",
# *ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_hdim64_fp16_sm80.cu"),
# *ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_hdim64_bf16_sm80.cu"),
# *ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_hdim128_fp16_sm80.cu"),
# *ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_hdim128_bf16_sm80.cu"),
# *ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_block_hdim64_fp16_sm80.cu"),
# *ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_block_hdim64_bf16_sm80.cu"),
# *ncond(
# "third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_block_hdim128_fp16_sm80.cu"
# ),
# *ncond(
# "third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_block_hdim128_bf16_sm80.cu"
# ),
# "src/kernels/activation_kernels.cu",
# "src/kernels/layernorm_kernels.cu",
# "src/kernels/misc_kernels.cu",
# "src/kernels/zgemm/gemm_w4a4.cu",
# "src/kernels/zgemm/gemm_w4a4_test.cu",
# "src/kernels/zgemm/gemm_w4a4_launch_fp16_int4.cu",
# "src/kernels/zgemm/gemm_w4a4_launch_fp16_int4_fasteri2f.cu",
# "src/kernels/zgemm/gemm_w4a4_launch_fp16_fp4.cu",
# "src/kernels/zgemm/gemm_w4a4_launch_bf16_int4.cu",
# "src/kernels/zgemm/gemm_w4a4_launch_bf16_fp4.cu",
# "src/kernels/zgemm/gemm_w8a8.cu",
# "src/kernels/zgemm/attention.cu",
# "src/kernels/dwconv.cu",
# "src/kernels/gemm_batched.cu",
# "src/kernels/gemm_f16.cu",
# "src/kernels/awq/gemm_awq.cu",
"src/kernels/awq/gemv_awq.cu",
*ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/flash_api.cpp"),
*ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/flash_api_adapter.cpp"),
......
......@@ -66,13 +66,13 @@ __forceinline__ __device__ void dequantize_s4_to_fp16x2(half2 const &source, uin
// Finally, we construct the output numbers.
// Convert elt_01
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
// asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
// Convert elt_23
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
// asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
// Convert elt_45
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
// asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
// Convert elt_67
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
// asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
}
__forceinline__ __device__ void dequantize_s4_to_fp16x2(__nv_bfloat162 const &source, uint4 *result) {
......
......@@ -30,7 +30,7 @@ struct Generator {
std::mutex mutex_;
};
namespace cuda {
namespace hip {
using ::getCurrentDeviceProperties;
struct StreamWrapper {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment