Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
6cb6bbe4
Commit
6cb6bbe4
authored
Apr 27, 2021
by
Rick Ho
Browse files
global exchange update variable name
parent
bb92d30e
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
21 additions
and
21 deletions
+21
-21
cuda/global_exchange.cpp
cuda/global_exchange.cpp
+6
-6
cuda/global_exchange.h
cuda/global_exchange.h
+15
-15
No files found.
cuda/global_exchange.cpp
View file @
6cb6bbe4
...
...
@@ -7,14 +7,14 @@
std
::
vector
<
torch
::
Tensor
>
_expert_exchange
(
torch
::
Tensor
local_expert_count
,
long
n
um
_expert
,
long
n_workers
)
{
long
n_expert
,
long
n_workers
)
{
auto
global_expert_count
=
torch
::
empty_like
(
local_expert_count
);
auto
smgr
=
getCudaStreamManager
(
local_expert_count
.
device
().
index
());
fmoe_cuda_expert_exchange_impl
(
local_expert_count
.
data_ptr
<
long
>
(),
global_expert_count
.
data_ptr
<
long
>
(),
n
um
_expert
,
n_workers
,
n_expert
,
n_workers
,
smgr
);
return
{
global_expert_count
};
}
...
...
@@ -26,7 +26,7 @@ std::vector<torch::Tensor> _global_scatter(
long
batch_size
,
long
n_workers
)
{
CHECK_INPUT
(
input_buf
);
auto
n
um
_expert
=
local_expert_count
.
size
(
0
)
/
n_workers
;
auto
n_expert
=
local_expert_count
.
size
(
0
)
/
n_workers
;
auto
in_feat
=
input_buf
.
size
(
1
);
auto
global_input_buf
=
input_buf
.
new_empty
({
batch_size
,
in_feat
});
auto
smgr
=
getCudaStreamManager
(
input_buf
.
device
().
index
());
...
...
@@ -38,7 +38,7 @@ std::vector<torch::Tensor> _global_scatter(
local_expert_count
.
data_ptr
<
long
>
(),
global_expert_count
.
data_ptr
<
long
>
(),
global_input_buf
.
data_ptr
<
scalar_t
>
(),
in_feat
,
n
um
_expert
,
n_workers
,
in_feat
,
n_expert
,
n_workers
,
smgr
);
}));
...
...
@@ -52,7 +52,7 @@ std::vector<torch::Tensor> _global_gather(
long
batch_size
,
long
n_workers
)
{
CHECK_INPUT
(
output_buf
);
auto
n
um
_expert
=
local_expert_count
.
size
(
0
)
/
n_workers
;
auto
n_expert
=
local_expert_count
.
size
(
0
)
/
n_workers
;
auto
out_feat
=
output_buf
.
size
(
1
);
auto
local_output_buf
=
output_buf
.
new_empty
({
batch_size
,
out_feat
});
auto
smgr
=
getCudaStreamManager
(
output_buf
.
device
().
index
());
...
...
@@ -64,7 +64,7 @@ std::vector<torch::Tensor> _global_gather(
local_expert_count
.
data_ptr
<
long
>
(),
global_expert_count
.
data_ptr
<
long
>
(),
local_output_buf
.
data_ptr
<
scalar_t
>
(),
out_feat
,
n
um
_expert
,
n_workers
,
out_feat
,
n_expert
,
n_workers
,
smgr
);
}));
...
...
cuda/global_exchange.h
View file @
6cb6bbe4
...
...
@@ -4,20 +4,20 @@
void
fmoe_cuda_expert_exchange_impl
(
const
long
*
local_expert_count
,
long
*
global_expert_count
,
int
n
um
_expert
,
int
world_size
,
int
n_expert
,
int
world_size
,
CudaStreamManager
*
smgr
)
{
NCCL_SAFE_CALL
(
ncclGroupStart
());
for
(
int
i
=
0
;
i
<
world_size
;
++
i
)
{
NCCL_SAFE_CALL
(
ncclSend
(
local_expert_count
+
n
um
_expert
*
i
,
n
um
_expert
,
local_expert_count
+
n_expert
*
i
,
n_expert
,
ncclInt64
,
i
,
smgr
->
ncclcomm
,
smgr
->
stream
(
0
)));
NCCL_SAFE_CALL
(
ncclRecv
(
global_expert_count
+
n
um
_expert
*
i
,
n
um
_expert
,
global_expert_count
+
n_expert
*
i
,
n_expert
,
ncclInt64
,
i
,
smgr
->
ncclcomm
,
...
...
@@ -33,21 +33,21 @@ void fmoe_cuda_global_scatter_impl(
const
long
*
local_expert_count
,
const
long
*
global_expert_count
,
scalar_t
*
input_buf
,
size_t
in_feat
,
size_t
n
um
_expert
,
size_t
world_size
,
size_t
in_feat
,
size_t
n_expert
,
size_t
world_size
,
CudaStreamManager
*
smgr
)
{
// assert world_size > 1
int
recv_ptr
=
0
;
/* TODO: may save for backward */
long
*
expert_ptr
=
new
long
[
n
um
_expert
*
world_size
];
long
*
expert_ptr
=
new
long
[
n_expert
*
world_size
];
expert_ptr
[
0
]
=
0
;
for
(
size_t
i
=
1
;
i
<
n
um
_expert
*
world_size
;
++
i
)
{
for
(
size_t
i
=
1
;
i
<
n_expert
*
world_size
;
++
i
)
{
expert_ptr
[
i
]
=
expert_ptr
[
i
-
1
]
+
local_expert_count
[
i
-
1
];
}
for
(
size_t
i
=
0
;
i
<
n
um
_expert
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
n_expert
;
++
i
)
{
NCCL_SAFE_CALL
(
ncclGroupStart
());
for
(
size_t
j
=
0
;
j
<
world_size
;
++
j
)
{
int
idx
=
i
+
j
*
n
um
_expert
;
int
idx
=
i
+
j
*
n_expert
;
if
(
local_expert_count
[
idx
])
{
NCCL_SAFE_CALL
(
ncclSend
(
local_input_buf
+
expert_ptr
[
idx
]
*
in_feat
,
...
...
@@ -80,20 +80,20 @@ void fmoe_cuda_global_gather_impl(
const
long
*
local_expert_count
,
const
long
*
global_expert_count
,
scalar_t
*
local_output_buf
,
size_t
out_feat
,
size_t
n
um
_expert
,
size_t
world_size
,
size_t
out_feat
,
size_t
n_expert
,
size_t
world_size
,
CudaStreamManager
*
smgr
)
{
long
send_ptr
=
0
;
/* TODO: may save for backward */
long
*
expert_ptr
=
new
long
[
n
um
_expert
*
world_size
];
long
*
expert_ptr
=
new
long
[
n_expert
*
world_size
];
expert_ptr
[
0
]
=
0
;
for
(
size_t
i
=
1
;
i
<
n
um
_expert
*
world_size
;
++
i
)
{
for
(
size_t
i
=
1
;
i
<
n_expert
*
world_size
;
++
i
)
{
expert_ptr
[
i
]
=
expert_ptr
[
i
-
1
]
+
local_expert_count
[
i
-
1
];
}
for
(
size_t
i
=
0
;
i
<
n
um
_expert
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
n_expert
;
++
i
)
{
NCCL_SAFE_CALL
(
ncclGroupStart
());
for
(
size_t
j
=
0
;
j
<
world_size
;
++
j
)
{
int
idx
=
i
+
j
*
n
um
_expert
;
int
idx
=
i
+
j
*
n_expert
;
if
(
global_expert_count
[
idx
])
{
NCCL_SAFE_CALL
(
ncclSend
(
output_buf
+
send_ptr
*
out_feat
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment