Commit 0c698cda authored by caihl's avatar caihl
Browse files

adapt to vllm-plugin-FL

parent aadf7b41
...@@ -90,7 +90,7 @@ void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, ...@@ -90,7 +90,7 @@ void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
reinterpret_cast<half*>(out.data_ptr()), out.numel()); reinterpret_cast<half*>(out.data_ptr()), out.numel());
break; break;
} }
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__) || defined(USE_ROCM))
case at::ScalarType::BFloat16: { case at::ScalarType::BFloat16: {
fa->allreduce<nv_bfloat16>( fa->allreduce<nv_bfloat16>(
stream, reinterpret_cast<nv_bfloat16*>(reg_buffer), stream, reinterpret_cast<nv_bfloat16*>(reg_buffer),
......
...@@ -105,7 +105,7 @@ DINLINE half& assign_add(half& a, half b) { ...@@ -105,7 +105,7 @@ DINLINE half& assign_add(half& a, half b) {
} }
DINLINE float& assign_add(float& a, float b) { return a += b; } DINLINE float& assign_add(float& a, float b) { return a += b; }
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__) || defined(USE_ROCM))
DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); } DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); }
template <> template <>
DINLINE nv_bfloat16 downcast_s(float val) { DINLINE nv_bfloat16 downcast_s(float val) {
......
...@@ -41,11 +41,11 @@ ...@@ -41,11 +41,11 @@
#if defined(HIP_VERSION) && HIP_VERSION < 70000000 #if defined(HIP_VERSION) && HIP_VERSION < 70000000
// On ROCm versions before 7.0, __syncwarp isn't defined. The below // On ROCm versions before 7.0, __syncwarp isn't defined. The below
// implementation is copy/pasted from the implementation in ROCm 7.0 // implementation is copy/pasted from the implementation in ROCm 7.0
__device__ inline void __syncwarp() { //__device__ inline void __syncwarp() {
__builtin_amdgcn_fence(__ATOMIC_RELEASE, "wavefront"); // __builtin_amdgcn_fence(__ATOMIC_RELEASE, "wavefront");
__builtin_amdgcn_wave_barrier(); // __builtin_amdgcn_wave_barrier();
__builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "wavefront"); // __builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "wavefront");
} //}
#endif #endif
#else #else
#define FINAL_MASK 0xffffffff #define FINAL_MASK 0xffffffff
......
...@@ -25,6 +25,8 @@ elif current_platform.is_rocm(): ...@@ -25,6 +25,8 @@ elif current_platform.is_rocm():
"Rocm platform requires upstream flash-attn " "Rocm platform requires upstream flash-attn "
"to be installed. Please install flash-attn first." "to be installed. Please install flash-attn first."
) from e ) from e
else:
from flash_attn import flash_attn_varlen_func
def get_flash_attn_version(requires_alibi: bool = False) -> int | None: def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
......
...@@ -143,15 +143,15 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -143,15 +143,15 @@ class CudaCommunicator(DeviceCommunicatorBase):
out = qr_comm.quick_all_reduce(input_) out = qr_comm.quick_all_reduce(input_)
assert out is not None assert out is not None
return out return out
ca_comm = self.ca_comm #ca_comm = self.ca_comm
if ( #if (
ca_comm is not None # ca_comm is not None
and not ca_comm.disabled # and not ca_comm.disabled
and ca_comm.should_custom_ar(input_) # and ca_comm.should_custom_ar(input_)
): #):
out = ca_comm.custom_all_reduce(input_) # out = ca_comm.custom_all_reduce(input_)
assert out is not None # assert out is not None
return out # return out
symm_mem_comm = self.symm_mem_comm symm_mem_comm = self.symm_mem_comm
if symm_mem_comm is not None and symm_mem_comm.should_use_symm_mem(input_): if symm_mem_comm is not None and symm_mem_comm.should_use_symm_mem(input_):
out = symm_mem_comm.all_reduce(input_) out = symm_mem_comm.all_reduce(input_)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment