Commit 18b1d183 authored by Xiaodong Ye's avatar Xiaodong Ye
Browse files

musa: support bf16


Signed-off-by: default avatarXiaodong Ye <xiaodong.ye@mthreads.com>
parent 94ab2de3
#pragma once #pragma once
#include <musa_runtime.h> #include <musa_runtime.h>
#include <musa_bf16.h>
#define cudaLaunchHostFunc musaLaunchHostFunc #define cudaLaunchHostFunc musaLaunchHostFunc
#define cudaStream_t musaStream_t #define cudaStream_t musaStream_t
#define cudaHostFn_t musaHostFn_t #define cudaHostFn_t musaHostFn_t
\ No newline at end of file #define nv_bfloat16 mt_bfloat16
\ No newline at end of file
...@@ -350,6 +350,7 @@ elif MUSA_HOME is not None: ...@@ -350,6 +350,7 @@ elif MUSA_HOME is not None:
"at::cuda": "at::musa", "at::cuda": "at::musa",
"#include <ATen/cuda/CUDAContext.h>": "#include \"torch_musa/csrc/aten/musa/MUSAContext.h\"", "#include <ATen/cuda/CUDAContext.h>": "#include \"torch_musa/csrc/aten/musa/MUSAContext.h\"",
"#include <c10/cuda/CUDAGuard.h>": "#include \"torch_musa/csrc/core/MUSAGuard.h\"", "#include <c10/cuda/CUDAGuard.h>": "#include \"torch_musa/csrc/core/MUSAGuard.h\"",
"nv_bfloat16": "mt_bfloat16",
}).run() }).run()
ops_module = MUSAExtension('KTransformersOps', [ ops_module = MUSAExtension('KTransformersOps', [
'ktransformers/ktransformers_ext/cuda_musa/custom_gguf/dequant.mu', 'ktransformers/ktransformers_ext/cuda_musa/custom_gguf/dequant.mu',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment