Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0c698cda
Commit
0c698cda
authored
Mar 05, 2026
by
caihl
Browse files
adapt to vllm-plugin-FL
parent
aadf7b41
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
20 additions
and
18 deletions
+20
-18
csrc/custom_all_reduce.cu
csrc/custom_all_reduce.cu
+1
-1
csrc/custom_all_reduce.cuh
csrc/custom_all_reduce.cuh
+2
-2
csrc/fused_qknorm_rope_kernel.cu
csrc/fused_qknorm_rope_kernel.cu
+6
-6
vllm/attention/utils/fa_utils.py
vllm/attention/utils/fa_utils.py
+2
-0
vllm/distributed/device_communicators/cuda_communicator.py
vllm/distributed/device_communicators/cuda_communicator.py
+9
-9
No files found.
csrc/custom_all_reduce.cu
View file @
0c698cda
...
@@ -90,7 +90,7 @@ void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
...
@@ -90,7 +90,7 @@ void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
reinterpret_cast
<
half
*>
(
out
.
data_ptr
()),
out
.
numel
());
reinterpret_cast
<
half
*>
(
out
.
data_ptr
()),
out
.
numel
());
break
;
break
;
}
}
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)
|| defined(USE_ROCM)
)
case
at
::
ScalarType
::
BFloat16
:
{
case
at
::
ScalarType
::
BFloat16
:
{
fa
->
allreduce
<
nv_bfloat16
>
(
fa
->
allreduce
<
nv_bfloat16
>
(
stream
,
reinterpret_cast
<
nv_bfloat16
*>
(
reg_buffer
),
stream
,
reinterpret_cast
<
nv_bfloat16
*>
(
reg_buffer
),
...
...
csrc/custom_all_reduce.cuh
View file @
0c698cda
...
@@ -105,7 +105,7 @@ DINLINE half& assign_add(half& a, half b) {
...
@@ -105,7 +105,7 @@ DINLINE half& assign_add(half& a, half b) {
}
}
DINLINE
float
&
assign_add
(
float
&
a
,
float
b
)
{
return
a
+=
b
;
}
DINLINE
float
&
assign_add
(
float
&
a
,
float
b
)
{
return
a
+=
b
;
}
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)
|| defined(USE_ROCM)
)
DINLINE
float
upcast_s
(
nv_bfloat16
val
)
{
return
__bfloat162float
(
val
);
}
DINLINE
float
upcast_s
(
nv_bfloat16
val
)
{
return
__bfloat162float
(
val
);
}
template
<
>
template
<
>
DINLINE
nv_bfloat16
downcast_s
(
float
val
)
{
DINLINE
nv_bfloat16
downcast_s
(
float
val
)
{
...
...
csrc/fused_qknorm_rope_kernel.cu
View file @
0c698cda
...
@@ -41,11 +41,11 @@
...
@@ -41,11 +41,11 @@
#if defined(HIP_VERSION) && HIP_VERSION < 70000000
#if defined(HIP_VERSION) && HIP_VERSION < 70000000
// On ROCm versions before 7.0, __syncwarp isn't defined. The below
// On ROCm versions before 7.0, __syncwarp isn't defined. The below
// implementation is copy/pasted from the implementation in ROCm 7.0
// implementation is copy/pasted from the implementation in ROCm 7.0
__device__
inline
void
__syncwarp
()
{
//
__device__ inline void __syncwarp() {
__builtin_amdgcn_fence
(
__ATOMIC_RELEASE
,
"wavefront"
);
//
__builtin_amdgcn_fence(__ATOMIC_RELEASE, "wavefront");
__builtin_amdgcn_wave_barrier
();
//
__builtin_amdgcn_wave_barrier();
__builtin_amdgcn_fence
(
__ATOMIC_ACQUIRE
,
"wavefront"
);
//
__builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "wavefront");
}
//
}
#endif
#endif
#else
#else
#define FINAL_MASK 0xffffffff
#define FINAL_MASK 0xffffffff
...
...
vllm/attention/utils/fa_utils.py
View file @
0c698cda
...
@@ -25,6 +25,8 @@ elif current_platform.is_rocm():
...
@@ -25,6 +25,8 @@ elif current_platform.is_rocm():
"Rocm platform requires upstream flash-attn "
"Rocm platform requires upstream flash-attn "
"to be installed. Please install flash-attn first."
"to be installed. Please install flash-attn first."
)
from
e
)
from
e
else
:
from
flash_attn
import
flash_attn_varlen_func
def
get_flash_attn_version
(
requires_alibi
:
bool
=
False
)
->
int
|
None
:
def
get_flash_attn_version
(
requires_alibi
:
bool
=
False
)
->
int
|
None
:
...
...
vllm/distributed/device_communicators/cuda_communicator.py
View file @
0c698cda
...
@@ -143,15 +143,15 @@ class CudaCommunicator(DeviceCommunicatorBase):
...
@@ -143,15 +143,15 @@ class CudaCommunicator(DeviceCommunicatorBase):
out
=
qr_comm
.
quick_all_reduce
(
input_
)
out
=
qr_comm
.
quick_all_reduce
(
input_
)
assert
out
is
not
None
assert
out
is
not
None
return
out
return
out
ca_comm
=
self
.
ca_comm
#
ca_comm = self.ca_comm
if
(
#
if (
ca_comm
is
not
None
#
ca_comm is not None
and
not
ca_comm
.
disabled
#
and not ca_comm.disabled
and
ca_comm
.
should_custom_ar
(
input_
)
#
and ca_comm.should_custom_ar(input_)
):
#
):
out
=
ca_comm
.
custom_all_reduce
(
input_
)
#
out = ca_comm.custom_all_reduce(input_)
assert
out
is
not
None
#
assert out is not None
return
out
#
return out
symm_mem_comm
=
self
.
symm_mem_comm
symm_mem_comm
=
self
.
symm_mem_comm
if
symm_mem_comm
is
not
None
and
symm_mem_comm
.
should_use_symm_mem
(
input_
):
if
symm_mem_comm
is
not
None
and
symm_mem_comm
.
should_use_symm_mem
(
input_
):
out
=
symm_mem_comm
.
all_reduce
(
input_
)
out
=
symm_mem_comm
.
all_reduce
(
input_
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment