Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
7a2ad4a1
"src/vscode:/vscode.git/clone" did not exist on "f8ed456e79c726a490b5223d9ecf4bcbc1811648"
Commit
7a2ad4a1
authored
Jan 10, 2021
by
Rick Ho
Browse files
use expert_count array instead of expert_n
parent
307e0ad9
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
40 additions
and
36 deletions
+40
-36
pytorch/cuda/moe_cuda_kernel.cu
pytorch/cuda/moe_cuda_kernel.cu
+40
-36
No files found.
pytorch/cuda/moe_cuda_kernel.cu
View file @
7a2ad4a1
...
@@ -58,7 +58,7 @@ void moe_cuda_expert_count_impl(
...
@@ -58,7 +58,7 @@ void moe_cuda_expert_count_impl(
++
expert_count
[
gate
[
i
]];
++
expert_count
[
gate
[
i
]];
}
}
expert_ptr
[
0
]
=
0
;
expert_ptr
[
0
]
=
0
;
for
(
int
i
=
1
;
i
<
tot
_expert
;
++
i
)
{
for
(
int
i
=
1
;
i
<
num
_expert
;
++
i
)
{
expert_ptr
[
i
]
=
expert_ptr
[
i
-
1
]
+
expert_count
[
i
-
1
];
expert_ptr
[
i
]
=
expert_ptr
[
i
-
1
]
+
expert_count
[
i
-
1
];
}
}
...
@@ -67,7 +67,7 @@ void moe_cuda_expert_count_impl(
...
@@ -67,7 +67,7 @@ void moe_cuda_expert_count_impl(
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
pos
[
i
]
=
expert_ptr
[
gate
[
i
]]
++
;
pos
[
i
]
=
expert_ptr
[
gate
[
i
]]
++
;
}
}
for
(
int
i
=
tot
_expert
-
1
;
i
>
0
;
--
i
)
{
for
(
int
i
=
num
_expert
-
1
;
i
>
0
;
--
i
)
{
expert_ptr
[
i
]
=
expert_ptr
[
i
-
1
];
expert_ptr
[
i
]
=
expert_ptr
[
i
-
1
];
}
}
expert_ptr
[
0
]
=
0
;
expert_ptr
[
0
]
=
0
;
...
@@ -77,6 +77,8 @@ void moe_cuda_expert_count_impl(
...
@@ -77,6 +77,8 @@ void moe_cuda_expert_count_impl(
delete
[]
expert_ptr
;
delete
[]
expert_ptr
;
}
}
#ifdef MOE_USE_NCCL
void
moe_cuda_global_scatter
()
{
void
moe_cuda_global_scatter
()
{
if
(
cm
->
size
>
1
)
{
if
(
cm
->
size
>
1
)
{
if
(
expert_sz
)
{
if
(
expert_sz
)
{
...
@@ -118,6 +120,40 @@ void moe_cuda_global_scatter() {
...
@@ -118,6 +120,40 @@ void moe_cuda_global_scatter() {
}
}
}
}
void
moe_cuda_global_gather
()
{
if
(
cm
->
size
>
1
)
{
int
send_ptr
=
0
;
for
(
int
i
=
0
;
i
<
num_expert
;
++
i
)
{
NCCL_SAFE_CALL
(
ncclGroupStart
());
for
(
int
j
=
0
;
j
<
cm
->
size
;
++
j
)
{
int
idx
=
i
+
j
*
num_expert
;
if
(
all_expert_count
[
idx
])
{
NCCL_SAFE_CALL
(
ncclSend
(
output_buf
+
send_ptr
*
out_feat
,
all_expert_count
[
idx
]
*
out_feat
*
sizeof
(
scalar_t
),
ncclChar
,
j
,
cm
->
ncclcomm
,
h
->
getStream
(
0
)));
send_ptr
+=
all_expert_count
[
idx
];
}
if
(
expert_count
[
idx
])
{
NCCL_SAFE_CALL
(
ncclRecv
(
local_output_buf
+
expert_ptr
[
idx
]
*
out_feat
,
expert_count
[
idx
]
*
out_feat
*
sizeof
(
scalar_t
),
ncclChar
,
j
,
cm
->
ncclcomm
,
h
->
getStream
(
0
)));
}
}
NCCL_SAFE_CALL
(
ncclGroupEnd
());
}
}
}
#endif // MOE_USE_NCCL
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
void
moe_cuda_local_scatter_impl
(
void
moe_cuda_local_scatter_impl
(
const
scalar_t
*
input
,
const
scalar_t
*
input
,
...
@@ -170,7 +206,7 @@ void moe_cuda_forward_impl(
...
@@ -170,7 +206,7 @@ void moe_cuda_forward_impl(
scalar_t
alpha
=
1
,
beta
=
0
;
scalar_t
alpha
=
1
,
beta
=
0
;
for
(
int
i
=
0
,
ptr
=
0
;
i
<
num_expert
;
++
i
)
{
for
(
int
i
=
0
,
ptr
=
0
;
i
<
num_expert
;
++
i
)
{
if
(
expert_
n
[
i
]
==
0
)
{
if
(
expert_
count
[
i
]
==
0
)
{
continue
;
continue
;
}
}
// Use T(B) x T(A) = T(C) to produce row-major C
// Use T(B) x T(A) = T(C) to produce row-major C
...
@@ -240,44 +276,12 @@ void moe_cuda_backward_impl(
...
@@ -240,44 +276,12 @@ void moe_cuda_backward_impl(
grad_weight
+
i
*
in_feat
*
out_feat
,
in_feat
grad_weight
+
i
*
in_feat
*
out_feat
,
in_feat
));
));
ptr
+=
expert_
n
[
i
];
ptr
+=
expert_
count
[
i
];
}
}
smgr
->
sync
(
num_expert
);
smgr
->
sync
(
num_expert
);
}
}
void
moe_cuda_global_gather
()
{
if
(
cm
->
size
>
1
)
{
int
send_ptr
=
0
;
for
(
int
i
=
0
;
i
<
num_expert
;
++
i
)
{
NCCL_SAFE_CALL
(
ncclGroupStart
());
for
(
int
j
=
0
;
j
<
cm
->
size
;
++
j
)
{
int
idx
=
i
+
j
*
num_expert
;
if
(
all_expert_count
[
idx
])
{
NCCL_SAFE_CALL
(
ncclSend
(
output_buf
+
send_ptr
*
out_feat
,
all_expert_count
[
idx
]
*
out_feat
*
sizeof
(
scalar_t
),
ncclChar
,
j
,
cm
->
ncclcomm
,
h
->
getStream
(
0
)));
send_ptr
+=
all_expert_count
[
idx
];
}
if
(
expert_count
[
idx
])
{
NCCL_SAFE_CALL
(
ncclRecv
(
local_output_buf
+
expert_ptr
[
idx
]
*
out_feat
,
expert_count
[
idx
]
*
out_feat
*
sizeof
(
scalar_t
),
ncclChar
,
j
,
cm
->
ncclcomm
,
h
->
getStream
(
0
)));
}
}
NCCL_SAFE_CALL
(
ncclGroupEnd
());
}
}
}
std
::
vector
<
torch
::
Tensor
>
moe_cuda_expert_count
(
std
::
vector
<
torch
::
Tensor
>
moe_cuda_expert_count
(
torch
::
Tensor
gate
,
torch
::
Tensor
gate
,
size_t
num_expert
)
{
size_t
num_expert
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment