"tests/git@developer.sourcefind.cn:OpenDAS/llama-factory.git" did not exist on "7ea81099235fd4ccf8d4b9ba202e76cce40b5cc8"
Commit 9fed1f5d authored by zhuwenwen's avatar zhuwenwen
Browse files

add bf16

parent 3f1166ab
...@@ -3,5 +3,5 @@ ...@@ -3,5 +3,5 @@
#include "attention_generic.cuh" #include "attention_generic.cuh"
#include "dtype_float16.cuh" #include "dtype_float16.cuh"
#include "dtype_float32.cuh" #include "dtype_float32.cuh"
// #include "dtype_bfloat16.cuh" #include "dtype_bfloat16.cuh"
// #include "dtype_fp8_e5m2.cuh" // #include "dtype_fp8_e5m2.cuh"
...@@ -734,8 +734,8 @@ void paged_attention_v1( ...@@ -734,8 +734,8 @@ void paged_attention_v1(
CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, false); CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, false);
} else if (query.dtype() == at::ScalarType::Half) { } else if (query.dtype() == at::ScalarType::Half) {
CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false); CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false);
// } else if (query.dtype() == at::ScalarType::BFloat16) { } else if (query.dtype() == at::ScalarType::BFloat16) {
// CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false); CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false);
} else { } else {
TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
} }
...@@ -927,8 +927,8 @@ void paged_attention_v2( ...@@ -927,8 +927,8 @@ void paged_attention_v2(
CALL_V2_LAUNCHER_BLOCK_SIZE(float, float, false); CALL_V2_LAUNCHER_BLOCK_SIZE(float, float, false);
} else if (query.dtype() == at::ScalarType::Half) { } else if (query.dtype() == at::ScalarType::Half) {
CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false); CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false);
// } else if (query.dtype() == at::ScalarType::BFloat16) { } else if (query.dtype() == at::ScalarType::BFloat16) {
// CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false); CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false);
} else { } else {
TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
} }
......
...@@ -13,10 +13,10 @@ ...@@ -13,10 +13,10 @@
#include <map> #include <map>
#include <vector> #include <vector>
// #ifdef USE_ROCM #ifdef USE_ROCM
// #include <hip/hip_bf16.h> #include <hip/hip_bf16.h>
// typedef __hip_bfloat16 __nv_bfloat16; typedef __hip_bfloat16 __nv_bfloat16;
// #endif #endif
void swap_blocks( void swap_blocks(
torch::Tensor& src, torch::Tensor& src,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment