Commit 0c698cda authored by caihl's avatar caihl
Browse files

adapt to vllm-plugin-FL

parent aadf7b41
......@@ -90,7 +90,7 @@ void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
reinterpret_cast<half*>(out.data_ptr()), out.numel());
break;
}
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__) || defined(USE_ROCM))
case at::ScalarType::BFloat16: {
fa->allreduce<nv_bfloat16>(
stream, reinterpret_cast<nv_bfloat16*>(reg_buffer),
......
......@@ -105,7 +105,7 @@ DINLINE half& assign_add(half& a, half b) {
}
DINLINE float& assign_add(float& a, float b) { return a += b; }
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__) || defined(USE_ROCM))
DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); }
template <>
DINLINE nv_bfloat16 downcast_s(float val) {
......@@ -638,4 +638,4 @@ class CustomAllreduce {
* template void vllm::CustomAllreduce::allreduce<half>(cudaStream_t, half *,
half *, int, int, int);
*/
} // namespace vllm
\ No newline at end of file
} // namespace vllm
......@@ -41,11 +41,11 @@
#if defined(HIP_VERSION) && HIP_VERSION < 70000000
// On ROCm versions before 7.0, __syncwarp isn't defined. The below
// implementation is copy/pasted from the implementation in ROCm 7.0
__device__ inline void __syncwarp() {
__builtin_amdgcn_fence(__ATOMIC_RELEASE, "wavefront");
__builtin_amdgcn_wave_barrier();
__builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "wavefront");
}
//__device__ inline void __syncwarp() {
// __builtin_amdgcn_fence(__ATOMIC_RELEASE, "wavefront");
// __builtin_amdgcn_wave_barrier();
// __builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "wavefront");
//}
#endif
#else
#define FINAL_MASK 0xffffffff
......@@ -425,4 +425,4 @@ void fused_qk_norm_rope(
stream);
});
});
}
\ No newline at end of file
}
......@@ -25,6 +25,8 @@ elif current_platform.is_rocm():
"Rocm platform requires upstream flash-attn "
"to be installed. Please install flash-attn first."
) from e
else:
from flash_attn import flash_attn_varlen_func
def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
......
......@@ -143,15 +143,15 @@ class CudaCommunicator(DeviceCommunicatorBase):
out = qr_comm.quick_all_reduce(input_)
assert out is not None
return out
ca_comm = self.ca_comm
if (
ca_comm is not None
and not ca_comm.disabled
and ca_comm.should_custom_ar(input_)
):
out = ca_comm.custom_all_reduce(input_)
assert out is not None
return out
#ca_comm = self.ca_comm
#if (
# ca_comm is not None
# and not ca_comm.disabled
# and ca_comm.should_custom_ar(input_)
#):
# out = ca_comm.custom_all_reduce(input_)
# assert out is not None
# return out
symm_mem_comm = self.symm_mem_comm
if symm_mem_comm is not None and symm_mem_comm.should_use_symm_mem(input_):
out = symm_mem_comm.all_reduce(input_)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment