Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
bd429ffe
Unverified
Commit
bd429ffe
authored
Jun 25, 2025
by
Shangyan Zhou
Committed by
GitHub
Jun 25, 2025
Browse files
Support bias. (#257)
* Support bias. * Fix. * Fix style.
parent
b80e55e2
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
101 additions
and
16 deletions
+101
-16
csrc/deep_ep.cpp
csrc/deep_ep.cpp
+28
-4
csrc/deep_ep.hpp
csrc/deep_ep.hpp
+2
-0
csrc/kernels/api.cuh
csrc/kernels/api.cuh
+2
-0
csrc/kernels/internode.cu
csrc/kernels/internode.cu
+29
-5
csrc/kernels/intranode.cu
csrc/kernels/intranode.cu
+18
-1
deep_ep/buffer.py
deep_ep/buffer.py
+18
-4
tests/test_internode.py
tests/test_internode.py
+4
-2
No files found.
csrc/deep_ep.cpp
View file @
bd429ffe
...
@@ -526,6 +526,7 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
...
@@ -526,6 +526,7 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
torch
::
Tensor
>
,
std
::
optional
<
EventHandle
>>
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
torch
::
Tensor
>
,
std
::
optional
<
EventHandle
>>
Buffer
::
intranode_combine
(
const
torch
::
Tensor
&
x
,
const
std
::
optional
<
torch
::
Tensor
>&
topk_weights
,
Buffer
::
intranode_combine
(
const
torch
::
Tensor
&
x
,
const
std
::
optional
<
torch
::
Tensor
>&
topk_weights
,
const
std
::
optional
<
torch
::
Tensor
>&
bias_0
,
const
std
::
optional
<
torch
::
Tensor
>&
bias_1
,
const
torch
::
Tensor
&
src_idx
,
const
torch
::
Tensor
&
rank_prefix_matrix
,
const
torch
::
Tensor
&
channel_prefix_matrix
,
const
torch
::
Tensor
&
src_idx
,
const
torch
::
Tensor
&
rank_prefix_matrix
,
const
torch
::
Tensor
&
channel_prefix_matrix
,
const
torch
::
Tensor
&
send_head
,
const
Config
&
config
,
std
::
optional
<
EventHandle
>&
previous_event
,
bool
async
,
bool
allocate_on_comm_stream
)
{
const
torch
::
Tensor
&
send_head
,
const
Config
&
config
,
std
::
optional
<
EventHandle
>&
previous_event
,
bool
async
,
bool
allocate_on_comm_stream
)
{
EP_HOST_ASSERT
(
x
.
dim
()
==
2
and
x
.
is_contiguous
());
EP_HOST_ASSERT
(
x
.
dim
()
==
2
and
x
.
is_contiguous
());
...
@@ -581,6 +582,17 @@ Buffer::intranode_combine(const torch::Tensor& x, const std::optional<torch::Ten
...
@@ -581,6 +582,17 @@ Buffer::intranode_combine(const torch::Tensor& x, const std::optional<torch::Ten
num_channels
,
num_recv_tokens
,
num_channels
*
num_ranks
*
2
,
num_channels
,
num_recv_tokens
,
num_channels
*
num_ranks
*
2
,
barrier_signal_ptrs_gpu
,
rank
,
num_ranks
,
barrier_signal_ptrs_gpu
,
rank
,
num_ranks
,
comm_stream
);
comm_stream
);
// Assign bias pointers
auto
bias_opts
=
std
::
vector
<
std
::
optional
<
torch
::
Tensor
>>
({
bias_0
,
bias_1
});
void
*
bias_ptrs
[
2
]
=
{
nullptr
,
nullptr
};
for
(
int
i
=
0
;
i
<
2
;
++
i
)
if
(
bias_opts
[
i
].
has_value
())
{
auto
bias
=
bias_opts
[
i
].
value
();
EP_HOST_ASSERT
(
bias
.
dim
()
==
2
and
bias
.
is_contiguous
());
EP_HOST_ASSERT
(
bias
.
scalar_type
()
==
x
.
scalar_type
());
EP_HOST_ASSERT
(
bias
.
size
(
0
)
==
num_recv_tokens
and
bias
.
size
(
1
)
==
hidden
);
bias_ptrs
[
i
]
=
bias
.
data_ptr
();
}
// Combine data
// Combine data
auto
recv_x
=
torch
::
empty
({
num_recv_tokens
,
hidden
},
x
.
options
());
auto
recv_x
=
torch
::
empty
({
num_recv_tokens
,
hidden
},
x
.
options
());
...
@@ -591,7 +603,7 @@ Buffer::intranode_combine(const torch::Tensor& x, const std::optional<torch::Ten
...
@@ -591,7 +603,7 @@ Buffer::intranode_combine(const torch::Tensor& x, const std::optional<torch::Ten
<=
num_nvl_bytes
);
<=
num_nvl_bytes
);
intranode
::
combine
(
at
::
cuda
::
ScalarTypeToCudaDataType
(
x
.
scalar_type
()),
intranode
::
combine
(
at
::
cuda
::
ScalarTypeToCudaDataType
(
x
.
scalar_type
()),
recv_x
.
data_ptr
(),
recv_topk_weights_ptr
,
recv_x
.
data_ptr
(),
recv_topk_weights_ptr
,
x
.
data_ptr
(),
topk_weights_ptr
,
x
.
data_ptr
(),
topk_weights_ptr
,
bias_ptrs
[
0
],
bias_ptrs
[
1
],
src_idx
.
data_ptr
<
int
>
(),
rank_prefix_matrix
.
data_ptr
<
int
>
(),
channel_prefix_matrix
.
data_ptr
<
int
>
(),
src_idx
.
data_ptr
<
int
>
(),
rank_prefix_matrix
.
data_ptr
<
int
>
(),
channel_prefix_matrix
.
data_ptr
<
int
>
(),
send_head
.
data_ptr
<
int
>
(),
num_tokens
,
num_recv_tokens
,
hidden
,
num_topk
,
send_head
.
data_ptr
<
int
>
(),
num_tokens
,
num_recv_tokens
,
hidden
,
num_topk
,
buffer_ptrs_gpu
,
rank
,
num_ranks
,
buffer_ptrs_gpu
,
rank
,
num_ranks
,
...
@@ -607,7 +619,7 @@ Buffer::intranode_combine(const torch::Tensor& x, const std::optional<torch::Ten
...
@@ -607,7 +619,7 @@ Buffer::intranode_combine(const torch::Tensor& x, const std::optional<torch::Ten
if
(
allocate_on_comm_stream
)
if
(
allocate_on_comm_stream
)
t
.
record_stream
(
compute_stream
);
t
.
record_stream
(
compute_stream
);
}
}
for
(
auto
&
to
:
{
topk_weights
,
recv_topk_weights
})
{
for
(
auto
&
to
:
{
topk_weights
,
recv_topk_weights
,
bias_0
,
bias_1
})
{
to
.
has_value
()
?
to
->
record_stream
(
comm_stream
)
:
void
();
to
.
has_value
()
?
to
->
record_stream
(
comm_stream
)
:
void
();
if
(
allocate_on_comm_stream
)
if
(
allocate_on_comm_stream
)
to
.
has_value
()
?
to
->
record_stream
(
compute_stream
)
:
void
();
to
.
has_value
()
?
to
->
record_stream
(
compute_stream
)
:
void
();
...
@@ -906,6 +918,7 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
...
@@ -906,6 +918,7 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
torch
::
Tensor
>
,
std
::
optional
<
EventHandle
>>
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
torch
::
Tensor
>
,
std
::
optional
<
EventHandle
>>
Buffer
::
internode_combine
(
const
torch
::
Tensor
&
x
,
const
std
::
optional
<
torch
::
Tensor
>&
topk_weights
,
Buffer
::
internode_combine
(
const
torch
::
Tensor
&
x
,
const
std
::
optional
<
torch
::
Tensor
>&
topk_weights
,
const
std
::
optional
<
torch
::
Tensor
>&
bias_0
,
const
std
::
optional
<
torch
::
Tensor
>&
bias_1
,
const
torch
::
Tensor
&
src_meta
,
const
torch
::
Tensor
&
is_combined_token_in_rank
,
const
torch
::
Tensor
&
src_meta
,
const
torch
::
Tensor
&
is_combined_token_in_rank
,
const
torch
::
Tensor
&
rdma_channel_prefix_matrix
,
const
torch
::
Tensor
&
rdma_rank_prefix_sum
,
const
torch
::
Tensor
&
gbl_channel_prefix_matrix
,
const
torch
::
Tensor
&
rdma_channel_prefix_matrix
,
const
torch
::
Tensor
&
rdma_rank_prefix_sum
,
const
torch
::
Tensor
&
gbl_channel_prefix_matrix
,
const
torch
::
Tensor
&
combined_rdma_head
,
const
torch
::
Tensor
&
combined_nvl_head
,
const
torch
::
Tensor
&
combined_rdma_head
,
const
torch
::
Tensor
&
combined_nvl_head
,
...
@@ -979,13 +992,24 @@ Buffer::internode_combine(const torch::Tensor& x, const std::optional<torch::Ten
...
@@ -979,13 +992,24 @@ Buffer::internode_combine(const torch::Tensor& x, const std::optional<torch::Ten
barrier_signal_ptrs_gpu
,
rank
,
comm_stream
,
barrier_signal_ptrs_gpu
,
rank
,
comm_stream
,
config
.
get_rdma_buffer_size_hint
(
hidden_int4
*
sizeof
(
int4
),
num_ranks
),
config
.
get_rdma_buffer_size_hint
(
hidden_int4
*
sizeof
(
int4
),
num_ranks
),
num_nvl_bytes
,
false
,
low_latency_mode
);
num_nvl_bytes
,
false
,
low_latency_mode
);
// Assign bias pointers
auto
bias_opts
=
std
::
vector
<
std
::
optional
<
torch
::
Tensor
>>
({
bias_0
,
bias_1
});
void
*
bias_ptrs
[
2
]
=
{
nullptr
,
nullptr
};
for
(
int
i
=
0
;
i
<
2
;
++
i
)
if
(
bias_opts
[
i
].
has_value
())
{
auto
bias
=
bias_opts
[
i
].
value
();
EP_HOST_ASSERT
(
bias
.
dim
()
==
2
and
bias
.
is_contiguous
());
EP_HOST_ASSERT
(
bias
.
scalar_type
()
==
x
.
scalar_type
());
EP_HOST_ASSERT
(
bias
.
size
(
0
)
==
num_combined_tokens
and
bias
.
size
(
1
)
==
hidden
);
bias_ptrs
[
i
]
=
bias
.
data_ptr
();
}
// Launch data combine
// Launch data combine
auto
combined_x
=
torch
::
empty
({
num_combined_tokens
,
hidden
},
x
.
options
());
auto
combined_x
=
torch
::
empty
({
num_combined_tokens
,
hidden
},
x
.
options
());
internode
::
combine
(
at
::
cuda
::
ScalarTypeToCudaDataType
(
x
.
scalar_type
()),
internode
::
combine
(
at
::
cuda
::
ScalarTypeToCudaDataType
(
x
.
scalar_type
()),
combined_x
.
data_ptr
(),
combined_topk_weights_ptr
,
combined_x
.
data_ptr
(),
combined_topk_weights_ptr
,
is_combined_token_in_rank
.
data_ptr
<
bool
>
(),
is_combined_token_in_rank
.
data_ptr
<
bool
>
(),
x
.
data_ptr
(),
topk_weights_ptr
,
x
.
data_ptr
(),
topk_weights_ptr
,
bias_ptrs
[
0
],
bias_ptrs
[
1
],
combined_rdma_head
.
data_ptr
<
int
>
(),
combined_nvl_head
.
data_ptr
<
int
>
(),
combined_rdma_head
.
data_ptr
<
int
>
(),
combined_nvl_head
.
data_ptr
<
int
>
(),
src_meta
.
data_ptr
(),
rdma_channel_prefix_matrix
.
data_ptr
<
int
>
(),
rdma_rank_prefix_sum
.
data_ptr
<
int
>
(),
gbl_channel_prefix_matrix
.
data_ptr
<
int
>
(),
src_meta
.
data_ptr
(),
rdma_channel_prefix_matrix
.
data_ptr
<
int
>
(),
rdma_rank_prefix_sum
.
data_ptr
<
int
>
(),
gbl_channel_prefix_matrix
.
data_ptr
<
int
>
(),
num_tokens
,
num_combined_tokens
,
hidden
,
num_topk
,
num_tokens
,
num_combined_tokens
,
hidden
,
num_topk
,
...
@@ -1004,7 +1028,7 @@ Buffer::internode_combine(const torch::Tensor& x, const std::optional<torch::Ten
...
@@ -1004,7 +1028,7 @@ Buffer::internode_combine(const torch::Tensor& x, const std::optional<torch::Ten
if
(
allocate_on_comm_stream
)
if
(
allocate_on_comm_stream
)
t
.
record_stream
(
compute_stream
);
t
.
record_stream
(
compute_stream
);
}
}
for
(
auto
&
to
:
{
topk_weights
,
combined_topk_weights
})
{
for
(
auto
&
to
:
{
topk_weights
,
combined_topk_weights
,
bias_0
,
bias_1
})
{
to
.
has_value
()
?
to
->
record_stream
(
comm_stream
)
:
void
();
to
.
has_value
()
?
to
->
record_stream
(
comm_stream
)
:
void
();
if
(
allocate_on_comm_stream
)
if
(
allocate_on_comm_stream
)
to
.
has_value
()
?
to
->
record_stream
(
compute_stream
)
:
void
();
to
.
has_value
()
?
to
->
record_stream
(
compute_stream
)
:
void
();
...
...
csrc/deep_ep.hpp
View file @
bd429ffe
...
@@ -112,6 +112,7 @@ public:
...
@@ -112,6 +112,7 @@ public:
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
torch
::
Tensor
>
,
std
::
optional
<
EventHandle
>>
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
torch
::
Tensor
>
,
std
::
optional
<
EventHandle
>>
intranode_combine
(
const
torch
::
Tensor
&
x
,
const
std
::
optional
<
torch
::
Tensor
>&
topk_weights
,
intranode_combine
(
const
torch
::
Tensor
&
x
,
const
std
::
optional
<
torch
::
Tensor
>&
topk_weights
,
const
std
::
optional
<
torch
::
Tensor
>&
bias_0
,
const
std
::
optional
<
torch
::
Tensor
>&
bias_1
,
const
torch
::
Tensor
&
src_idx
,
const
torch
::
Tensor
&
rank_prefix_matrix
,
const
torch
::
Tensor
&
channel_prefix_matrix
,
const
torch
::
Tensor
&
src_idx
,
const
torch
::
Tensor
&
rank_prefix_matrix
,
const
torch
::
Tensor
&
channel_prefix_matrix
,
const
torch
::
Tensor
&
send_head
,
const
Config
&
config
,
std
::
optional
<
EventHandle
>&
previous_event
,
bool
async
,
bool
allocate_on_comm_stream
);
const
torch
::
Tensor
&
send_head
,
const
Config
&
config
,
std
::
optional
<
EventHandle
>&
previous_event
,
bool
async
,
bool
allocate_on_comm_stream
);
...
@@ -127,6 +128,7 @@ public:
...
@@ -127,6 +128,7 @@ public:
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
torch
::
Tensor
>
,
std
::
optional
<
EventHandle
>>
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
torch
::
Tensor
>
,
std
::
optional
<
EventHandle
>>
internode_combine
(
const
torch
::
Tensor
&
x
,
const
std
::
optional
<
torch
::
Tensor
>&
topk_weights
,
internode_combine
(
const
torch
::
Tensor
&
x
,
const
std
::
optional
<
torch
::
Tensor
>&
topk_weights
,
const
std
::
optional
<
torch
::
Tensor
>&
bias_0
,
const
std
::
optional
<
torch
::
Tensor
>&
bias_1
,
const
torch
::
Tensor
&
src_meta
,
const
torch
::
Tensor
&
is_combined_token_in_rank
,
const
torch
::
Tensor
&
src_meta
,
const
torch
::
Tensor
&
is_combined_token_in_rank
,
const
torch
::
Tensor
&
rdma_channel_prefix_matrix
,
const
torch
::
Tensor
&
rdma_rank_prefix_sum
,
const
torch
::
Tensor
&
gbl_channel_prefix_matrix
,
const
torch
::
Tensor
&
rdma_channel_prefix_matrix
,
const
torch
::
Tensor
&
rdma_rank_prefix_sum
,
const
torch
::
Tensor
&
gbl_channel_prefix_matrix
,
const
torch
::
Tensor
&
combined_rdma_head
,
const
torch
::
Tensor
&
combined_nvl_head
,
const
torch
::
Tensor
&
combined_rdma_head
,
const
torch
::
Tensor
&
combined_nvl_head
,
...
...
csrc/kernels/api.cuh
View file @
bd429ffe
...
@@ -68,6 +68,7 @@ void cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels,
...
@@ -68,6 +68,7 @@ void cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels,
void
combine
(
cudaDataType_t
type
,
void
combine
(
cudaDataType_t
type
,
void
*
recv_x
,
float
*
recv_topk_weights
,
void
*
recv_x
,
float
*
recv_topk_weights
,
const
void
*
x
,
const
float
*
topk_weights
,
const
void
*
x
,
const
float
*
topk_weights
,
const
void
*
bias_0
,
const
void
*
bias_1
,
const
int
*
src_idx
,
const
int
*
rank_prefix_matrix
,
const
int
*
channel_prefix_matrix
,
const
int
*
src_idx
,
const
int
*
rank_prefix_matrix
,
const
int
*
channel_prefix_matrix
,
int
*
send_head
,
int
num_tokens
,
int
num_recv_tokens
,
int
hidden
,
int
num_topk
,
int
*
send_head
,
int
num_tokens
,
int
num_recv_tokens
,
int
hidden
,
int
num_topk
,
void
**
buffer_ptrs
,
int
rank
,
int
num_ranks
,
void
**
buffer_ptrs
,
int
rank
,
int
num_ranks
,
...
@@ -121,6 +122,7 @@ void combine(cudaDataType_t type,
...
@@ -121,6 +122,7 @@ void combine(cudaDataType_t type,
void
*
combined_x
,
float
*
combined_topk_weights
,
void
*
combined_x
,
float
*
combined_topk_weights
,
const
bool
*
is_combined_token_in_rank
,
const
bool
*
is_combined_token_in_rank
,
const
void
*
x
,
const
float
*
topk_weights
,
const
void
*
x
,
const
float
*
topk_weights
,
const
void
*
bias_0
,
const
void
*
bias_1
,
const
int
*
combined_rdma_head
,
const
int
*
combined_nvl_head
,
const
int
*
combined_rdma_head
,
const
int
*
combined_nvl_head
,
const
void
*
src_meta
,
const
int
*
rdma_channel_prefix_matrix
,
const
int
*
rdma_rank_prefix_sum
,
const
int
*
gbl_channel_prefix_matrix
,
const
void
*
src_meta
,
const
int
*
rdma_channel_prefix_matrix
,
const
int
*
rdma_rank_prefix_sum
,
const
int
*
gbl_channel_prefix_matrix
,
int
num_tokens
,
int
num_combined_tokens
,
int
hidden
,
int
num_topk
,
int
num_tokens
,
int
num_combined_tokens
,
int
hidden
,
int
num_topk
,
...
...
csrc/kernels/internode.cu
View file @
bd429ffe
...
@@ -1139,10 +1139,11 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
...
@@ -1139,10 +1139,11 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
is_cached_dispatch
,
cpu_rdma_team
);
is_cached_dispatch
,
cpu_rdma_team
);
}
}
template
<
int
kNumRanks
,
typename
dtype_t
,
int
kMaxNumRanks
,
typename
ReceiveFn
,
typename
ReceiveTWFn
>
template
<
int
kNumRanks
,
bool
kMaybeWithBias
,
typename
dtype_t
,
int
kMaxNumRanks
,
typename
ReceiveFn
,
typename
ReceiveTWFn
>
__device__
int
combine_token
(
bool
is_token_in_rank
,
int
head_idx
,
__device__
int
combine_token
(
bool
is_token_in_rank
,
int
head_idx
,
int
lane_id
,
int
hidden_int4
,
int
num_topk
,
int
lane_id
,
int
hidden_int4
,
int
num_topk
,
int4
*
combined_row
,
float
*
combined_topk_weights
,
int4
*
combined_row
,
float
*
combined_topk_weights
,
const
int4
*
bias_0_int4
,
const
int4
*
bias_1_int4
,
int
num_max_recv_tokens
,
const
ReceiveFn
&
recv_fn
,
const
ReceiveTWFn
&
recv_tw_fn
)
{
int
num_max_recv_tokens
,
const
ReceiveFn
&
recv_fn
,
const
ReceiveTWFn
&
recv_tw_fn
)
{
constexpr
auto
kDtypePerInt4
=
sizeof
(
int4
)
/
sizeof
(
dtype_t
);
constexpr
auto
kDtypePerInt4
=
sizeof
(
int4
)
/
sizeof
(
dtype_t
);
...
@@ -1160,15 +1161,33 @@ __device__ int combine_token(bool is_token_in_rank, int head_idx,
...
@@ -1160,15 +1161,33 @@ __device__ int combine_token(bool is_token_in_rank, int head_idx,
// Reduce data
// Reduce data
#pragma unroll
#pragma unroll
for
(
int
i
=
lane_id
;
i
<
hidden_int4
;
i
+=
32
)
{
for
(
int
i
=
lane_id
;
i
<
hidden_int4
;
i
+=
32
)
{
// Read bias
// TODO: make it as a finer-grained template
int4
bias_0_value_int4
,
bias_1_value_int4
;
if
(
kMaybeWithBias
)
{
bias_0_value_int4
=
bias_0_int4
!=
nullptr
?
ld_nc_global
(
bias_0_int4
+
i
)
:
make_int4
(
0
,
0
,
0
,
0
);
bias_1_value_int4
=
bias_1_int4
!=
nullptr
?
ld_nc_global
(
bias_1_int4
+
i
)
:
make_int4
(
0
,
0
,
0
,
0
);
}
// Read buffers
// Read buffers
// TODO: maybe too many registers here
// TODO: maybe too many registers here
int4
recv_value_int4
[
kMaxNumRanks
];
int4
recv_value_int4
[
kMaxNumRanks
];
#pragma unroll
#pragma unroll
for
(
int
j
=
0
;
j
<
num_topk_ranks
;
++
j
)
for
(
int
j
=
0
;
j
<
num_topk_ranks
;
++
j
)
recv_value_int4
[
j
]
=
recv_fn
(
topk_ranks
[
j
],
slot_indices
[
j
],
i
);
recv_value_int4
[
j
]
=
recv_fn
(
topk_ranks
[
j
],
slot_indices
[
j
],
i
);
// Clean
// Reduce bias
float
values
[
kDtypePerInt4
]
=
{
0
};
if
(
kMaybeWithBias
)
{
auto
bias_0_values
=
reinterpret_cast
<
const
dtype_t
*>
(
&
bias_0_value_int4
);
auto
bias_1_values
=
reinterpret_cast
<
const
dtype_t
*>
(
&
bias_1_value_int4
);
#pragma unroll
for
(
int
j
=
0
;
j
<
kDtypePerInt4
;
++
j
)
values
[
j
]
=
static_cast
<
float
>
(
bias_0_values
[
j
])
+
static_cast
<
float
>
(
bias_1_values
[
j
]);
}
// Reduce all-to-all results
// Reduce all-to-all results
float
values
[
kDtypePerInt4
]
=
{
0
};
#pragma unroll
#pragma unroll
for
(
int
j
=
0
;
j
<
num_topk_ranks
;
++
j
)
{
for
(
int
j
=
0
;
j
<
num_topk_ranks
;
++
j
)
{
auto
recv_value_dtypes
=
reinterpret_cast
<
const
dtype_t
*>
(
&
recv_value_int4
[
j
]);
auto
recv_value_dtypes
=
reinterpret_cast
<
const
dtype_t
*>
(
&
recv_value_int4
[
j
]);
...
@@ -1210,6 +1229,7 @@ __global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32,
...
@@ -1210,6 +1229,7 @@ __global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32,
combine
(
int4
*
combined_x
,
float
*
combined_topk_weights
,
combine
(
int4
*
combined_x
,
float
*
combined_topk_weights
,
const
bool
*
is_combined_token_in_rank
,
const
bool
*
is_combined_token_in_rank
,
const
int4
*
x
,
const
float
*
topk_weights
,
const
int4
*
x
,
const
float
*
topk_weights
,
const
int4
*
bias_0
,
const
int4
*
bias_1
,
const
int
*
combined_rdma_head
,
const
int
*
combined_nvl_head
,
const
int
*
combined_rdma_head
,
const
int
*
combined_nvl_head
,
const
SourceMeta
*
src_meta
,
const
int
*
rdma_channel_prefix_matrix
,
const
int
*
rdma_rank_prefix_sum
,
const
int
*
gbl_channel_prefix_matrix
,
const
SourceMeta
*
src_meta
,
const
int
*
rdma_channel_prefix_matrix
,
const
int
*
rdma_rank_prefix_sum
,
const
int
*
gbl_channel_prefix_matrix
,
int
num_tokens
,
int
num_combined_tokens
,
int
hidden
,
int
num_topk
,
int
num_tokens
,
int
num_combined_tokens
,
int
hidden
,
int
num_topk
,
...
@@ -1470,12 +1490,12 @@ combine(int4* combined_x, float* combined_topk_weights,
...
@@ -1470,12 +1490,12 @@ combine(int4* combined_x, float* combined_topk_weights,
void
*
shifted
=
send_buffer
+
rdma_slot_idx
*
num_bytes_per_rdma_token
;
void
*
shifted
=
send_buffer
+
rdma_slot_idx
*
num_bytes_per_rdma_token
;
auto
recv_fn
=
[
&
](
int
src_nvl_rank
,
int
slot_idx
,
int
hidden_int4_idx
)
->
int4
{
return
ld_nc_global
(
nvl_channel_x
.
buffer
(
src_nvl_rank
)
+
slot_idx
*
hidden_int4
+
hidden_int4_idx
);
};
auto
recv_fn
=
[
&
](
int
src_nvl_rank
,
int
slot_idx
,
int
hidden_int4_idx
)
->
int4
{
return
ld_nc_global
(
nvl_channel_x
.
buffer
(
src_nvl_rank
)
+
slot_idx
*
hidden_int4
+
hidden_int4_idx
);
};
auto
recv_tw_fn
=
[
&
](
int
src_nvl_rank
,
int
slot_idx
,
int
topk_idx
)
->
float
{
return
ld_nc_global
(
nvl_channel_topk_weights
.
buffer
(
src_nvl_rank
)
+
slot_idx
*
num_topk
+
topk_idx
);
};
auto
recv_tw_fn
=
[
&
](
int
src_nvl_rank
,
int
slot_idx
,
int
topk_idx
)
->
float
{
return
ld_nc_global
(
nvl_channel_topk_weights
.
buffer
(
src_nvl_rank
)
+
slot_idx
*
num_topk
+
topk_idx
);
};
combine_token
<
NUM_MAX_NVL_PEERS
,
dtype_t
,
NUM_MAX_NVL_PEERS
>
(
expected_head
>=
0
,
combine_token
<
NUM_MAX_NVL_PEERS
,
false
,
dtype_t
,
NUM_MAX_NVL_PEERS
>
(
expected_head
>=
0
,
expected_head
,
lane_id
,
expected_head
,
lane_id
,
hidden_int4
,
num_topk
,
hidden_int4
,
num_topk
,
static_cast
<
int4
*>
(
shifted
),
static_cast
<
int4
*>
(
shifted
),
reinterpret_cast
<
float
*>
(
static_cast
<
int8_t
*>
(
shifted
)
+
hidden_bytes
+
sizeof
(
SourceMeta
)),
reinterpret_cast
<
float
*>
(
static_cast
<
int8_t
*>
(
shifted
)
+
hidden_bytes
+
sizeof
(
SourceMeta
)),
num_max_nvl_chunked_recv_tokens_per_rdma
,
recv_fn
,
recv_tw_fn
);
nullptr
,
nullptr
,
num_max_nvl_chunked_recv_tokens_per_rdma
,
recv_fn
,
recv_tw_fn
);
// Update head
// Update head
if
(
lane_id
<
NUM_MAX_NVL_PEERS
)
if
(
lane_id
<
NUM_MAX_NVL_PEERS
)
...
@@ -1549,11 +1569,13 @@ combine(int4* combined_x, float* combined_topk_weights,
...
@@ -1549,11 +1569,13 @@ combine(int4* combined_x, float* combined_topk_weights,
// Combine current token
// Combine current token
auto
recv_fn
=
[
&
](
int
src_rdma_rank
,
int
slot_idx
,
int
hidden_int4_idx
)
->
int4
{
return
ld_nc_global
(
reinterpret_cast
<
const
int4
*>
(
rdma_channel_data
.
recv_buffer
(
src_rdma_rank
)
+
slot_idx
*
num_bytes_per_rdma_token
)
+
hidden_int4_idx
);};
auto
recv_fn
=
[
&
](
int
src_rdma_rank
,
int
slot_idx
,
int
hidden_int4_idx
)
->
int4
{
return
ld_nc_global
(
reinterpret_cast
<
const
int4
*>
(
rdma_channel_data
.
recv_buffer
(
src_rdma_rank
)
+
slot_idx
*
num_bytes_per_rdma_token
)
+
hidden_int4_idx
);};
auto
recv_tw_fn
=
[
&
](
int
src_rdma_rank
,
int
slot_idx
,
int
topk_idx
)
->
float
{
return
ld_nc_global
(
reinterpret_cast
<
const
float
*>
(
rdma_channel_data
.
recv_buffer
(
src_rdma_rank
)
+
slot_idx
*
num_bytes_per_rdma_token
+
hidden_bytes
+
sizeof
(
SourceMeta
))
+
topk_idx
);};
auto
recv_tw_fn
=
[
&
](
int
src_rdma_rank
,
int
slot_idx
,
int
topk_idx
)
->
float
{
return
ld_nc_global
(
reinterpret_cast
<
const
float
*>
(
rdma_channel_data
.
recv_buffer
(
src_rdma_rank
)
+
slot_idx
*
num_bytes_per_rdma_token
+
hidden_bytes
+
sizeof
(
SourceMeta
))
+
topk_idx
);};
combine_token
<
kNumRDMARanks
,
dtype_t
,
kNumTopkRDMARanks
>
(
expected_head
>=
0
,
combine_token
<
kNumRDMARanks
,
true
,
dtype_t
,
kNumTopkRDMARanks
>
(
expected_head
>=
0
,
expected_head
,
lane_id
,
expected_head
,
lane_id
,
hidden_int4
,
num_topk
,
hidden_int4
,
num_topk
,
combined_x
+
token_idx
*
hidden_int4
,
combined_x
+
token_idx
*
hidden_int4
,
combined_topk_weights
+
token_idx
*
num_topk
,
combined_topk_weights
+
token_idx
*
num_topk
,
bias_0
==
nullptr
?
nullptr
:
bias_0
+
token_idx
*
hidden_int4
,
bias_1
==
nullptr
?
nullptr
:
bias_1
+
token_idx
*
hidden_int4
,
num_max_rdma_chunked_recv_tokens
,
recv_fn
,
recv_tw_fn
);
num_max_rdma_chunked_recv_tokens
,
recv_fn
,
recv_tw_fn
);
}
}
...
@@ -1614,6 +1636,7 @@ void combine(cudaDataType_t type,
...
@@ -1614,6 +1636,7 @@ void combine(cudaDataType_t type,
void
*
combined_x
,
float
*
combined_topk_weights
,
void
*
combined_x
,
float
*
combined_topk_weights
,
const
bool
*
is_combined_token_in_rank
,
const
bool
*
is_combined_token_in_rank
,
const
void
*
x
,
const
float
*
topk_weights
,
const
void
*
x
,
const
float
*
topk_weights
,
const
void
*
bias_0
,
const
void
*
bias_1
,
const
int
*
combined_rdma_head
,
const
int
*
combined_nvl_head
,
const
int
*
combined_rdma_head
,
const
int
*
combined_nvl_head
,
const
void
*
src_meta
,
const
int
*
rdma_channel_prefix_matrix
,
const
int
*
rdma_rank_prefix_sum
,
const
int
*
gbl_channel_prefix_matrix
,
const
void
*
src_meta
,
const
int
*
rdma_channel_prefix_matrix
,
const
int
*
rdma_rank_prefix_sum
,
const
int
*
gbl_channel_prefix_matrix
,
int
num_tokens
,
int
num_combined_tokens
,
int
hidden
,
int
num_topk
,
int
num_tokens
,
int
num_combined_tokens
,
int
hidden
,
int
num_topk
,
...
@@ -1628,6 +1651,7 @@ void combine(cudaDataType_t type,
...
@@ -1628,6 +1651,7 @@ void combine(cudaDataType_t type,
LAUNCH_KERNEL(&cfg, combine_func, \
LAUNCH_KERNEL(&cfg, combine_func, \
reinterpret_cast<int4*>(combined_x), combined_topk_weights, is_combined_token_in_rank, \
reinterpret_cast<int4*>(combined_x), combined_topk_weights, is_combined_token_in_rank, \
reinterpret_cast<const int4*>(x), topk_weights, \
reinterpret_cast<const int4*>(x), topk_weights, \
reinterpret_cast<const int4*>(bias_0), reinterpret_cast<const int4*>(bias_1), \
combined_rdma_head, combined_nvl_head, \
combined_rdma_head, combined_nvl_head, \
reinterpret_cast<const SourceMeta*>(src_meta), rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix, \
reinterpret_cast<const SourceMeta*>(src_meta), rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix, \
num_tokens, num_combined_tokens, hidden, num_topk, \
num_tokens, num_combined_tokens, hidden, num_topk, \
...
...
csrc/kernels/intranode.cu
View file @
bd429ffe
...
@@ -587,6 +587,7 @@ template<typename dtype_t, int kNumRanks, int kNumThreads, int kNumTMABytesPerWa
...
@@ -587,6 +587,7 @@ template<typename dtype_t, int kNumRanks, int kNumThreads, int kNumTMABytesPerWa
__global__
void
__launch_bounds__
(
kNumThreads
,
1
)
__global__
void
__launch_bounds__
(
kNumThreads
,
1
)
combine
(
dtype_t
*
recv_x
,
float
*
recv_topk_weights
,
combine
(
dtype_t
*
recv_x
,
float
*
recv_topk_weights
,
const
dtype_t
*
x
,
const
float
*
topk_weights
,
const
dtype_t
*
x
,
const
float
*
topk_weights
,
const
dtype_t
*
bias_0
,
const
dtype_t
*
bias_1
,
const
int
*
src_idx
,
const
int
*
rank_prefix_matrix
,
const
int
*
channel_prefix_matrix
,
const
int
*
src_idx
,
const
int
*
rank_prefix_matrix
,
const
int
*
channel_prefix_matrix
,
int
*
send_head
,
int
num_tokens
,
int
num_recv_tokens
,
int
hidden
,
int
num_topk
,
int
*
send_head
,
int
num_tokens
,
int
num_recv_tokens
,
int
hidden
,
int
num_topk
,
void
**
buffer_ptrs
,
int
rank
,
void
**
buffer_ptrs
,
int
rank
,
...
@@ -602,6 +603,8 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
...
@@ -602,6 +603,8 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
constexpr
int
kDtypePerInt4
=
sizeof
(
int4
)
/
sizeof
(
dtype_t
);
constexpr
int
kDtypePerInt4
=
sizeof
(
int4
)
/
sizeof
(
dtype_t
);
int
hidden_int4
=
hidden
*
sizeof
(
dtype_t
)
/
sizeof
(
int4
);
int
hidden_int4
=
hidden
*
sizeof
(
dtype_t
)
/
sizeof
(
int4
);
auto
x_int4
=
reinterpret_cast
<
const
int4
*>
(
x
);
auto
x_int4
=
reinterpret_cast
<
const
int4
*>
(
x
);
auto
bias_0_int4
=
reinterpret_cast
<
const
int4
*>
(
bias_0
);
auto
bias_1_int4
=
reinterpret_cast
<
const
int4
*>
(
bias_1
);
auto
recv_int4
=
reinterpret_cast
<
int4
*>
(
recv_x
);
auto
recv_int4
=
reinterpret_cast
<
int4
*>
(
recv_x
);
// TMA stuffs
// TMA stuffs
...
@@ -809,14 +812,26 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
...
@@ -809,14 +812,26 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
EP_STATIC_ASSERT
(
kNumStages
*
32
*
sizeof
(
int4
)
<=
kNumTMABytesPerWarp
,
"Invalid count"
);
EP_STATIC_ASSERT
(
kNumStages
*
32
*
sizeof
(
int4
)
<=
kNumTMABytesPerWarp
,
"Invalid count"
);
#pragma unroll
#pragma unroll
for
(
int
i
=
lane_id
;
i
<
hidden_int4
;
i
+=
32
)
{
for
(
int
i
=
lane_id
;
i
<
hidden_int4
;
i
+=
32
)
{
// Read bias
// TODO: make it as a template
int4
bias_0_value_int4
=
bias_0_int4
!=
nullptr
?
__ldg
(
bias_0_int4
+
token_idx
*
hidden_int4
+
i
)
:
make_int4
(
0
,
0
,
0
,
0
);
int4
bias_1_value_int4
=
bias_1_int4
!=
nullptr
?
__ldg
(
bias_1_int4
+
token_idx
*
hidden_int4
+
i
)
:
make_int4
(
0
,
0
,
0
,
0
);
// Read buffers
// Read buffers
int4
recv_value_int4
[
kNumRanks
];
int4
recv_value_int4
[
kNumRanks
];
#pragma unroll
#pragma unroll
for
(
int
j
=
0
;
j
<
num_topk_ranks
;
++
j
)
for
(
int
j
=
0
;
j
<
num_topk_ranks
;
++
j
)
recv_value_int4
[
j
]
=
ld_nc_global
(
channel_x_buffers
[
topk_ranks
[
j
]].
buffer
()
+
slot_indices
[
j
]
*
hidden_int4
+
i
);
recv_value_int4
[
j
]
=
ld_nc_global
(
channel_x_buffers
[
topk_ranks
[
j
]].
buffer
()
+
slot_indices
[
j
]
*
hidden_int4
+
i
);
// Reduce bias
float
values
[
kDtypePerInt4
];
auto
bias_0_values
=
reinterpret_cast
<
const
dtype_t
*>
(
&
bias_0_value_int4
);
auto
bias_1_values
=
reinterpret_cast
<
const
dtype_t
*>
(
&
bias_1_value_int4
);
#pragma unroll
for
(
int
j
=
0
;
j
<
kDtypePerInt4
;
++
j
)
values
[
j
]
=
static_cast
<
float
>
(
bias_0_values
[
j
])
+
static_cast
<
float
>
(
bias_1_values
[
j
]);
// Reduce all-to-all results
// Reduce all-to-all results
float
values
[
kDtypePerInt4
]
=
{
0
};
#pragma unroll
#pragma unroll
for
(
int
j
=
0
;
j
<
num_topk_ranks
;
++
j
)
{
for
(
int
j
=
0
;
j
<
num_topk_ranks
;
++
j
)
{
auto
recv_value_dtypes
=
reinterpret_cast
<
const
dtype_t
*>
(
&
recv_value_int4
[
j
]);
auto
recv_value_dtypes
=
reinterpret_cast
<
const
dtype_t
*>
(
&
recv_value_int4
[
j
]);
...
@@ -887,6 +902,7 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
...
@@ -887,6 +902,7 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
void
combine
(
cudaDataType_t
type
,
void
combine
(
cudaDataType_t
type
,
void
*
recv_x
,
float
*
recv_topk_weights
,
void
*
recv_x
,
float
*
recv_topk_weights
,
const
void
*
x
,
const
float
*
topk_weights
,
const
void
*
x
,
const
float
*
topk_weights
,
const
void
*
bias_0
,
const
void
*
bias_1
,
const
int
*
src_idx
,
const
int
*
rank_prefix_matrix
,
const
int
*
channel_prefix_matrix
,
const
int
*
src_idx
,
const
int
*
rank_prefix_matrix
,
const
int
*
channel_prefix_matrix
,
int
*
send_head
,
int
num_tokens
,
int
num_recv_tokens
,
int
hidden
,
int
num_topk
,
int
*
send_head
,
int
num_tokens
,
int
num_recv_tokens
,
int
hidden
,
int
num_topk
,
void
**
buffer_ptrs
,
int
rank
,
int
num_ranks
,
void
**
buffer_ptrs
,
int
rank
,
int
num_ranks
,
...
@@ -904,6 +920,7 @@ void combine(cudaDataType_t type,
...
@@ -904,6 +920,7 @@ void combine(cudaDataType_t type,
LAUNCH_KERNEL(&cfg, kernel, \
LAUNCH_KERNEL(&cfg, kernel, \
reinterpret_cast<dtype*>(recv_x), recv_topk_weights, \
reinterpret_cast<dtype*>(recv_x), recv_topk_weights, \
reinterpret_cast<const dtype*>(x), topk_weights, \
reinterpret_cast<const dtype*>(x), topk_weights, \
reinterpret_cast<const dtype*>(bias_0), reinterpret_cast<const dtype*>(bias_1), \
src_idx, rank_prefix_matrix, channel_prefix_matrix, \
src_idx, rank_prefix_matrix, channel_prefix_matrix, \
send_head, num_tokens, num_recv_tokens, hidden, num_topk, \
send_head, num_tokens, num_recv_tokens, hidden, num_topk, \
buffer_ptrs, rank, \
buffer_ptrs, rank, \
...
...
deep_ep/buffer.py
View file @
bd429ffe
...
@@ -176,6 +176,16 @@ class Buffer:
...
@@ -176,6 +176,16 @@ class Buffer:
assert
tensor
.
numel
()
>=
size
.
numel
()
assert
tensor
.
numel
()
>=
size
.
numel
()
return
tensor
[:
size
.
numel
()].
view
(
size
)
return
tensor
[:
size
.
numel
()].
view
(
size
)
@
staticmethod
def
_unpack_bias
(
bias
:
Union
[
torch
.
Tensor
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]):
bias_0
,
bias_1
=
None
,
None
if
isinstance
(
bias
,
torch
.
Tensor
):
bias_0
=
bias
elif
isinstance
(
bias
,
tuple
):
assert
len
(
bias
)
==
2
bias_0
,
bias_1
=
bias
return
bias_0
,
bias_1
@
staticmethod
@
staticmethod
def
get_dispatch_config
(
num_ranks
:
int
)
->
Config
:
def
get_dispatch_config
(
num_ranks
:
int
)
->
Config
:
"""
"""
...
@@ -346,6 +356,7 @@ class Buffer:
...
@@ -346,6 +356,7 @@ class Buffer:
# noinspection PyTypeChecker
# noinspection PyTypeChecker
def
combine
(
self
,
x
:
torch
.
Tensor
,
handle
:
Tuple
,
def
combine
(
self
,
x
:
torch
.
Tensor
,
handle
:
Tuple
,
topk_weights
:
Optional
[
torch
.
Tensor
]
=
None
,
topk_weights
:
Optional
[
torch
.
Tensor
]
=
None
,
bias
:
Union
[
torch
.
Tensor
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]
=
None
,
config
:
Optional
[
Config
]
=
None
,
config
:
Optional
[
Config
]
=
None
,
previous_event
:
Optional
[
EventOverlap
]
=
None
,
async_finish
:
bool
=
False
,
previous_event
:
Optional
[
EventOverlap
]
=
None
,
async_finish
:
bool
=
False
,
allocate_on_comm_stream
:
bool
=
False
)
->
\
allocate_on_comm_stream
:
bool
=
False
)
->
\
...
@@ -376,14 +387,15 @@ class Buffer:
...
@@ -376,14 +387,15 @@ class Buffer:
# Internode
# Internode
if
self
.
runtime
.
get_num_rdma_ranks
()
>
1
:
if
self
.
runtime
.
get_num_rdma_ranks
()
>
1
:
return
self
.
internode_combine
(
x
,
handle
,
topk_weights
,
config
,
previous_event
,
async_finish
,
allocate_on_comm_stream
)
return
self
.
internode_combine
(
x
,
handle
,
topk_weights
,
bias
,
config
,
previous_event
,
async_finish
,
allocate_on_comm_stream
)
# NOTES: the second `_` is for the sending side, so we should use the third one
# NOTES: the second `_` is for the sending side, so we should use the third one
rank_prefix_matrix
,
_
,
channel_prefix_matrix
,
src_idx
,
is_recv_token_in_rank
,
send_head
=
handle
rank_prefix_matrix
,
_
,
channel_prefix_matrix
,
src_idx
,
is_recv_token_in_rank
,
send_head
=
handle
bias_0
,
bias_1
=
Buffer
.
_unpack_bias
(
bias
)
# Launch the kernel
# Launch the kernel
recv_x
,
recv_topk_weights
,
event
=
self
.
runtime
.
intranode_combine
(
recv_x
,
recv_topk_weights
,
event
=
self
.
runtime
.
intranode_combine
(
x
,
topk_weights
,
x
,
topk_weights
,
bias_0
,
bias_1
,
src_idx
,
rank_prefix_matrix
,
channel_prefix_matrix
,
send_head
,
config
,
src_idx
,
rank_prefix_matrix
,
channel_prefix_matrix
,
send_head
,
config
,
getattr
(
previous_event
,
'event'
,
None
),
async_finish
,
allocate_on_comm_stream
)
getattr
(
previous_event
,
'event'
,
None
),
async_finish
,
allocate_on_comm_stream
)
return
recv_x
,
recv_topk_weights
,
EventOverlap
(
event
)
return
recv_x
,
recv_topk_weights
,
EventOverlap
(
event
)
...
@@ -442,6 +454,7 @@ class Buffer:
...
@@ -442,6 +454,7 @@ class Buffer:
# noinspection PyTypeChecker
# noinspection PyTypeChecker
def
internode_combine
(
self
,
x
:
torch
.
Tensor
,
handle
:
Union
[
tuple
,
list
],
def
internode_combine
(
self
,
x
:
torch
.
Tensor
,
handle
:
Union
[
tuple
,
list
],
topk_weights
:
Optional
[
torch
.
Tensor
]
=
None
,
topk_weights
:
Optional
[
torch
.
Tensor
]
=
None
,
bias
:
Union
[
torch
.
Tensor
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]
=
None
,
config
:
Optional
[
Config
]
=
None
,
config
:
Optional
[
Config
]
=
None
,
previous_event
:
Optional
[
EventOverlap
]
=
None
,
async_finish
:
bool
=
False
,
previous_event
:
Optional
[
EventOverlap
]
=
None
,
async_finish
:
bool
=
False
,
allocate_on_comm_stream
:
bool
=
False
)
->
\
allocate_on_comm_stream
:
bool
=
False
)
->
\
...
@@ -452,15 +465,16 @@ class Buffer:
...
@@ -452,15 +465,16 @@ class Buffer:
"""
"""
assert
config
is
not
None
assert
config
is
not
None
# Unpack handle
# Unpack handle
and bias
is_combined_token_in_rank
,
\
is_combined_token_in_rank
,
\
_
,
_
,
\
_
,
_
,
\
rdma_channel_prefix_matrix
,
rdma_rank_prefix_sum
,
gbl_channel_prefix_matrix
,
gbl_rank_prefix_sum
,
\
rdma_channel_prefix_matrix
,
rdma_rank_prefix_sum
,
gbl_channel_prefix_matrix
,
gbl_rank_prefix_sum
,
\
src_meta
,
send_rdma_head
,
send_nvl_head
=
handle
src_meta
,
send_rdma_head
,
send_nvl_head
=
handle
bias_0
,
bias_1
=
Buffer
.
_unpack_bias
(
bias
)
# Launch the kernel
# Launch the kernel
combined_x
,
combined_topk_weights
,
event
=
self
.
runtime
.
internode_combine
(
combined_x
,
combined_topk_weights
,
event
=
self
.
runtime
.
internode_combine
(
x
,
topk_weights
,
x
,
topk_weights
,
bias_0
,
bias_1
,
src_meta
,
is_combined_token_in_rank
,
src_meta
,
is_combined_token_in_rank
,
rdma_channel_prefix_matrix
,
rdma_rank_prefix_sum
,
gbl_channel_prefix_matrix
,
rdma_channel_prefix_matrix
,
rdma_rank_prefix_sum
,
gbl_channel_prefix_matrix
,
send_rdma_head
,
send_nvl_head
,
config
,
getattr
(
previous_event
,
'event'
,
None
),
send_rdma_head
,
send_nvl_head
,
config
,
getattr
(
previous_event
,
'event'
,
None
),
...
...
tests/test_internode.py
View file @
bd429ffe
...
@@ -140,14 +140,16 @@ def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: in
...
@@ -140,14 +140,16 @@ def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: in
check_data
(
recv_x
,
recv_gbl_rank_prefix_sum
)
check_data
(
recv_x
,
recv_gbl_rank_prefix_sum
)
# Test combine
# Test combine
combine_args
=
{
'x'
:
recv_x
,
'handle'
:
handle
,
'config'
:
config
,
'async_finish'
:
async_mode
}
bias_0
=
torch
.
ones
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
bias_1
=
torch
.
randn
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
combine_args
=
{
'x'
:
recv_x
,
'bias'
:
(
bias_0
,
bias_1
),
'handle'
:
handle
,
'config'
:
config
,
'async_finish'
:
async_mode
}
if
with_topk
:
if
with_topk
:
combine_args
.
update
({
'topk_weights'
:
recv_topk_weights
})
combine_args
.
update
({
'topk_weights'
:
recv_topk_weights
})
if
previous_mode
:
if
previous_mode
:
combine_args
.
update
({
'previous_event'
:
buffer
.
capture
()})
combine_args
.
update
({
'previous_event'
:
buffer
.
capture
()})
combined_x
,
combined_topk_weights
,
event
=
buffer
.
combine
(
**
combine_args
)
combined_x
,
combined_topk_weights
,
event
=
buffer
.
combine
(
**
combine_args
)
event
.
current_stream_wait
()
if
async_mode
else
()
event
.
current_stream_wait
()
if
async_mode
else
()
check_x
=
combined_x
.
float
()
/
is_token_in_rank
.
sum
(
dim
=
1
).
unsqueeze
(
1
)
check_x
=
(
combined_x
.
float
()
-
bias_0
.
float
()
-
bias_1
.
float
())
/
is_token_in_rank
.
sum
(
dim
=
1
).
unsqueeze
(
1
)
ref_x
=
x_pure_rand
if
current_x
is
x_pure_rand
else
x
ref_x
=
x_pure_rand
if
current_x
is
x_pure_rand
else
x
assert
calc_diff
(
check_x
,
ref_x
)
<
5e-6
assert
calc_diff
(
check_x
,
ref_x
)
<
5e-6
if
with_topk
:
if
with_topk
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment