OpenDAS / FastMoE / Commits

Commit ec322e4b, authored Jan 10, 2021 by Rick Ho

    global scatter gather kernels and pytorch C function

Parent: 7a2ad4a1

Showing 11 changed files with 253 additions and 144 deletions (+253 −144).
Changed files:

  pytorch/cuda/comm_manager.cpp         +0   −13
  pytorch/cuda/comm_manager.h           +0   −34
  pytorch/cuda/cuda_stream_manager.cpp  +11  −0
  pytorch/cuda/cuda_stream_manager.h    +18  −0
  pytorch/cuda/moe.cpp                  +30  −23
  pytorch/cuda/moe.py                   +5   −2
  pytorch/cuda/moe_cuda_kernel.cu       +138 −68
  pytorch/cuda/moe_cuda_kernel.h        +46  −0
  pytorch/cuda/moe_test.py              +2   −2
  pytorch/cuda/run.sh                   +1   −1
  pytorch/cuda/setup.py                 +2   −1
pytorch/cuda/comm_manager.cpp (deleted, 100644 → 0)

#include "comm_manager.h"

CommManager* comm_mgr = 0;

CommManager* getCommManager() {
    if (!comm_mgr) {
        comm_mgr = new CommManager();
    }
    return comm_mgr;
}
pytorch/cuda/comm_manager.h (deleted, 100644 → 0)

#ifndef COMM_MANAGER_H
#define COMM_MANAGER_H

#define NCCL_SAFE_CALL(__fn__) { \
    auto __res__ = __fn__; \
    if (__res__ != ncclSuccess) { \
        fprintf(stderr, "NCCL Error at %s:%d value %d\n", __FILE__, __LINE__, __res__); \
        exit(-1); \
    } \
}

#include <mpi.h>
#include "nccl.h"

struct CommManager {
    int rank, size;
    ncclComm_t ncclcomm;

    CommManager() {
        MPI_Comm_rank(MPI_COMM_WORLD, &rank);
        MPI_Comm_size(MPI_COMM_WORLD, &size);
        ncclUniqueId uid;
        if (rank == 0) {
            ncclGetUniqueId(&uid);
        }
        MPI_Bcast(&uid, sizeof(uid), MPI_BYTE, 0, MPI_COMM_WORLD);
        NCCL_SAFE_CALL(ncclCommInitRank(&ncclcomm, size, uid, rank));
    }
};

CommManager* getCommManager();

#endif  // COMM_MANAGER
pytorch/cuda/cuda_stream_manager.cpp

@@ -32,6 +32,17 @@ void CudaStreamManager::setup(const int device) {
         checkCudaErrors(cublasCreate(handles + i));
         cublasSetStream(handles[i], streams[i]);
     }
+#ifdef MOE_USE_NCCL
+    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
+    MPI_Comm_size(MPI_COMM_WORLD, &size);
+    ncclUniqueId uid;
+    if (rank == 0) {
+        ncclGetUniqueId(&uid);
+    }
+    MPI_Bcast(&uid, sizeof(uid), MPI_BYTE, 0, MPI_COMM_WORLD);
+    NCCL_SAFE_CALL(ncclCommInitRank(&ncclcomm, size, uid, rank));
+#endif
 }
 
 void CudaStreamManager::destroy() {
pytorch/cuda/cuda_stream_manager.h

@@ -4,11 +4,29 @@
 #include <cuda_runtime.h>
 #include <cublas_v2.h>
 
+#ifdef MOE_USE_NCCL
+#include <mpi.h>
+#include <nccl.h>
+
+#define NCCL_SAFE_CALL(__fn__) { \
+    auto __res__ = __fn__; \
+    if (__res__ != ncclSuccess) { \
+        fprintf(stderr, "NCCL Error at %s:%d value %d\n", __FILE__, __LINE__, __res__); \
+        exit(-1); \
+    } \
+}
+#endif
+
 class CudaStreamManager {
 public:
     int device;
     cublasHandle_t* handles;
     cudaStream_t* streams;
+#ifdef MOE_USE_NCCL
+    int rank, size;
+    ncclComm_t ncclcomm;
+#endif
 
 public:
     CudaStreamManager(int device_) : device(device_) {
pytorch/cuda/moe.cpp

@@ -4,29 +4,7 @@
 #include <iostream>
 #include <vector>
 
-std::vector<torch::Tensor> moe_cuda_expert_count(
-        torch::Tensor gate, size_t num_expert);
-std::vector<torch::Tensor> moe_cuda_local_scatter(
-        torch::Tensor input, torch::Tensor pos);
-std::vector<torch::Tensor> moe_cuda_local_gather(
-        torch::Tensor output_buf, torch::Tensor pos);
-std::vector<torch::Tensor> moe_cuda_forward(
-        torch::Tensor input_buf, torch::Tensor weight,
-        torch::Tensor expert_count);
-std::vector<torch::Tensor> moe_cuda_backward(
-        torch::Tensor grad_output_buf, torch::Tensor input_buf,
-        torch::Tensor weight, torch::Tensor expert_count);
+#include "moe_cuda_kernel.h"
 
 // C++ interface
 // NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
 #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")

@@ -87,6 +65,31 @@ std::vector<torch::Tensor> moe_backward(
     return moe_cuda_backward(grad_output_buf, input_buf, weight, expert_count);
 }
 
+#ifdef MOE_USE_NCCL
+std::vector<torch::Tensor> moe_global_scatter(
+        torch::Tensor input_buf,
+        torch::Tensor local_expert_count,
+        torch::Tensor global_expert_count,
+        size_t batch_size, size_t n_workers) {
+    CHECK_INPUT(input_buf);
+    return moe_cuda_global_scatter(input_buf,
+            local_expert_count, global_expert_count,
+            batch_size, n_workers);
+}
+
+std::vector<torch::Tensor> moe_global_gather(
+        torch::Tensor output_buf,
+        torch::Tensor local_expert_count,
+        torch::Tensor global_expert_count,
+        size_t batch_size, size_t n_workers) {
+    CHECK_INPUT(output_buf);
+    return moe_cuda_global_gather(output_buf,
+            local_expert_count, global_expert_count,
+            batch_size, n_workers);
+}
+#endif
+
 /*
 int main() {

@@ -103,6 +106,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("expert_count", &moe_expert_count, "MoE expert count (CUDA)");
   m.def("local_scatter", &moe_local_scatter, "MoE local scatter (CUDA)");
   m.def("local_gather", &moe_local_gather, "MoE local gather (CUDA)");
+#ifdef MOE_USE_NCCL
+  m.def("global_scatter", &moe_global_scatter, "MoE global scatter (CUDA)");
+  m.def("global_gather", &moe_global_gather, "MoE global gather (CUDA)");
+#endif
   m.def("forward", &moe_forward, "MoE forward (CUDA)");
   m.def("backward", &moe_backward, "MoE backward (CUDA)");
 }
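For orientation, below is a minimal sketch of how the two new bindings might be driven from Python. The extension module name (moe_cuda), the concrete count values, and the way the count tensors are produced are assumptions for illustration; only the argument order follows the C++ signatures above, and a real run needs one such process per worker with NCCL initialized.

    import torch
    import moe_cuda  # assumed name of the compiled extension module

    n_workers, num_expert, in_feat = 2, 2, 8  # hypothetical sizes

    # local_expert_count[i + j * num_expert]: rows this worker sends to expert i on worker j.
    # global_expert_count: the matching counts this worker will receive (exchanged beforehand).
    local_expert_count = torch.tensor([3, 1, 2, 2], dtype=torch.int32)
    global_expert_count = torch.tensor([3, 2, 1, 2], dtype=torch.int32)

    local_input_buf = torch.rand(int(local_expert_count.sum()), in_feat).cuda()

    # Route rows to the workers that own their experts.
    global_input_buf, = moe_cuda.global_scatter(
        local_input_buf, local_expert_count, global_expert_count,
        int(global_expert_count.sum()), n_workers)

    # ... run the local experts on global_input_buf, then send results back.
    global_output_buf = global_input_buf  # placeholder for the expert computation
    local_output_buf, = moe_cuda.global_gather(
        global_output_buf, local_expert_count, global_expert_count,
        int(local_expert_count.sum()), n_workers)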
pytorch/cuda/moe.py

@@ -102,8 +102,11 @@ def test():
     moe_raw.weight.data = moe.weight.data.clone()
 
     inp = torch.rand(batch_size, in_feat).cuda()
-    gate = torch.randint(low=0, high=num_expert, size=(batch_size,), requires_grad=False).int().cuda()
-    gate = torch.Tensor([0, 1, 0, 1]).int().cuda()
+    gate = torch.randint(low=0,
+            high=num_expert * torch.distributed.get_world_size(),
+            size=(batch_size,),
+            requires_grad=False).int().cuda()
+    # gate = torch.Tensor([0, 1, 0, 1]).int().cuda()
 
     moe_out = test_module(moe, linear, inp.clone(), gate.clone())
     raw_out = test_module(moe_raw, linear, inp.clone(), gate.clone())
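The test now draws gate values from the global expert range [0, num_expert * world_size). Under the flattened layout idx = i + j * num_expert used by the CUDA code, the per-destination counts for such a gate could be derived as in this sketch (illustrative only, not part of the commit; the sizes are made up):

    import torch

    num_expert, world_size, batch_size = 2, 2, 8  # hypothetical sizes

    gate = torch.randint(low=0, high=num_expert * world_size, size=(batch_size,)).int()

    # One slot per (worker j, expert i); gate value e maps to worker e // num_expert and
    # local expert e % num_expert, i.e. exactly slot idx = e in the flattened layout.
    local_expert_count = torch.bincount(gate.long(), minlength=num_expert * world_size).int()
    print(local_expert_count)  # e.g. tensor([3, 1, 2, 2], dtype=torch.int32)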
pytorch/cuda/moe_cuda_kernel.cu

-#include <torch/extension.h>
-#include <torch/torch.h>
+#include "moe_cuda_kernel.h"
 
 #include <cstdio>
 #include <iostream>
 #include <vector>

@@ -10,13 +10,16 @@
 #include <helper_cuda.h>
 #include <c10/cuda/CUDAGuard.h>
 
+#ifdef MOE_USE_NCCL
 #include <mpi.h>
+#include <nccl.h>
+#endif
+
 #include "timer.hh"
 #include "cublas_wrapper.h"
 #include "cuda_stream_manager.h"
-#include "comm_manager.h"
 
 #define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)

@@ -79,79 +82,146 @@ void moe_cuda_expert_count_impl(
 #ifdef MOE_USE_NCCL
 
 [removed: the earlier draft moe_cuda_global_scatter() / moe_cuda_global_gather() functions, which relied on
  the deleted CommManager (cm->size, cm->ncclcomm), on h->getStream(0), and on cudaMalloc'd temporary
  input_buf / output_buf allocations]
 
+template<typename scalar_t>
+void moe_cuda_global_scatter_impl(
+    const scalar_t* local_input_buf,
+    const int* local_expert_count,
+    const int* global_expert_count,
+    scalar_t* input_buf,
+    size_t in_feat, size_t num_expert, size_t world_size,
+    CudaStreamManager* smgr) {
+    // assert world_size > 1
+    int recv_ptr = 0;
+    /* TODO: may save for backward */
+    int *expert_ptr = new int[num_expert * world_size];
+    expert_ptr[0] = 0;
+    for (int i = 1; i < num_expert * world_size; ++i) {
+        expert_ptr[i] = expert_ptr[i - 1] + local_expert_count[i - 1];
+    }
+
+    for (int i = 0; i < num_expert; ++i) {
+        NCCL_SAFE_CALL(ncclGroupStart());
+        for (int j = 0; j < world_size; ++j) {
+            int idx = i + j * num_expert;
+            if (local_expert_count[idx]) {
+                NCCL_SAFE_CALL(ncclSend(
+                        local_input_buf + expert_ptr[idx] * in_feat,
+                        local_expert_count[idx] * in_feat * sizeof(scalar_t),
+                        ncclChar, j,
+                        smgr->ncclcomm, smgr->stream(0)));
+            }
+            if (global_expert_count[idx]) {
+                NCCL_SAFE_CALL(ncclRecv(
+                        input_buf + recv_ptr * in_feat,
+                        global_expert_count[idx] * in_feat * sizeof(scalar_t),
+                        ncclChar, j,
+                        smgr->ncclcomm, smgr->stream(0)));
+                recv_ptr += global_expert_count[idx];
+            }
+        }
+        NCCL_SAFE_CALL(ncclGroupEnd());
+    }
+    delete [] expert_ptr;
+}
+
+std::vector<torch::Tensor> moe_cuda_global_scatter(
+        torch::Tensor input_buf,
+        torch::Tensor local_expert_count,
+        torch::Tensor global_expert_count,
+        long batch_size, long n_workers) {
+    auto num_expert = local_expert_count.size(0) / n_workers;
+    auto in_feat = input_buf.size(1);
+    auto global_input_buf = input_buf.new_empty({batch_size, in_feat});
+    auto smgr = getCudaStreamManager(input_buf.device().index());
+
+    AT_DISPATCH_FLOATING_TYPES(input_buf.scalar_type(),
+            "moe_cuda_global_scatter", ([&] {
+        moe_cuda_global_scatter_impl<scalar_t>(
+            input_buf.data_ptr<scalar_t>(),
+            local_expert_count.data_ptr<int>(),
+            global_expert_count.data_ptr<int>(),
+            global_input_buf.data_ptr<scalar_t>(),
+            in_feat, num_expert, n_workers,
+            smgr
+        );
+    }));
+    return {global_input_buf,};
+}
+
+template<typename scalar_t>
+void moe_cuda_global_gather_impl(
+    const scalar_t* output_buf,
+    const int* local_expert_count,
+    const int* global_expert_count,
+    scalar_t* local_output_buf,
+    size_t out_feat, size_t num_expert, size_t world_size,
+    CudaStreamManager* smgr) {
+    int send_ptr = 0;
+    /* TODO: may save for backward */
+    int *expert_ptr = new int[num_expert * world_size];
+    expert_ptr[0] = 0;
+    for (int i = 1; i < num_expert * world_size; ++i) {
+        expert_ptr[i] = expert_ptr[i - 1] + local_expert_count[i - 1];
+    }
+
+    for (int i = 0; i < num_expert; ++i) {
+        NCCL_SAFE_CALL(ncclGroupStart());
+        for (int j = 0; j < world_size; ++j) {
+            int idx = i + j * num_expert;
+            if (global_expert_count[idx]) {
+                NCCL_SAFE_CALL(ncclSend(
+                        output_buf + send_ptr * out_feat,
+                        global_expert_count[idx] * out_feat * sizeof(scalar_t),
+                        ncclChar, j,
+                        smgr->ncclcomm, smgr->stream(0)));
+                send_ptr += global_expert_count[idx];
+            }
+            if (local_expert_count[idx]) {
+                NCCL_SAFE_CALL(ncclRecv(
+                        local_output_buf + expert_ptr[idx] * out_feat,
+                        local_expert_count[idx] * out_feat * sizeof(scalar_t),
+                        ncclChar, j,
+                        smgr->ncclcomm, smgr->stream(0)));
+            }
+        }
+        NCCL_SAFE_CALL(ncclGroupEnd());
+    }
+    delete [] expert_ptr;
+}
+
+std::vector<torch::Tensor> moe_cuda_global_gather(
+        torch::Tensor output_buf,
+        torch::Tensor local_expert_count,
+        torch::Tensor global_expert_count,
+        long batch_size, long n_workers) {
+    auto num_expert = local_expert_count.size(0) / n_workers;
+    auto out_feat = output_buf.size(1);
+    auto local_output_buf = output_buf.new_empty({batch_size, out_feat});
+    auto smgr = getCudaStreamManager(output_buf.device().index());
+
+    AT_DISPATCH_FLOATING_TYPES(output_buf.scalar_type(),
+            "moe_cuda_global_gather", ([&] {
+        moe_cuda_global_scatter_impl<scalar_t>(
+            output_buf.data_ptr<scalar_t>(),
+            local_expert_count.data_ptr<int>(),
+            global_expert_count.data_ptr<int>(),
+            local_output_buf.data_ptr<scalar_t>(),
+            out_feat, num_expert, n_workers,
+            smgr
+        );
+    }));
+    return {local_output_buf,};
+}
 
 #endif  // MOE_USE_NCCL
 
 template <typename scalar_t>

@@ -159,8 +229,8 @@ void moe_cuda_local_scatter_impl(
         const scalar_t* input,
         const int* d_pos,
         scalar_t* input_buf,
-        const size_t batch_size,
-        const size_t in_feat,
+        const long batch_size,
+        const long in_feat,
         CudaStreamManager* smgr) {
     batch_scatter_kernel<scalar_t>
         <<<batch_size, 256, 0, smgr->stream(0)>>>(in_feat, d_pos, input,
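The heart of the new kernels is the loop structure in moe_cuda_global_scatter_impl / moe_cuda_global_gather_impl: for every expert i, each pair of workers exchanges one contiguous slice of rows, with send offsets coming from an exclusive prefix sum over local_expert_count (expert_ptr) and receive offsets accumulated in recv_ptr. A small single-process Python model of that bookkeeping (no NCCL; the example counts are made up):

    # Model of the offset bookkeeping in moe_cuda_global_scatter_impl.
    # local_expert_count[i + j * num_expert] = rows this worker sends to expert i on worker j.
    num_expert, world_size = 2, 2
    local_expert_count = [3, 1, 2, 2]
    global_expert_count = [3, 2, 1, 2]

    # expert_ptr: exclusive prefix sum over local_expert_count (send offsets).
    expert_ptr = [0] * (num_expert * world_size)
    for i in range(1, num_expert * world_size):
        expert_ptr[i] = expert_ptr[i - 1] + local_expert_count[i - 1]

    recv_ptr = 0
    for i in range(num_expert):          # one ncclGroupStart/End per expert
        for j in range(world_size):      # peer worker
            idx = i + j * num_expert
            if local_expert_count[idx]:
                print("send rows [%d, %d) to worker %d"
                      % (expert_ptr[idx], expert_ptr[idx] + local_expert_count[idx], j))
            if global_expert_count[idx]:
                print("recv %d rows from worker %d at offset %d"
                      % (global_expert_count[idx], j, recv_ptr))
                recv_ptr += global_expert_count[idx]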
pytorch/cuda/moe_cuda_kernel.h (new file, 0 → 100644)

#ifndef MOE_CUDA_KERNEL_H
#define MOE_CUDA_KERNEL_H

#include <vector>

#include <torch/extension.h>
#include <torch/torch.h>

std::vector<torch::Tensor> moe_cuda_expert_count(
        torch::Tensor gate,
        size_t num_expert);

std::vector<torch::Tensor> moe_cuda_local_scatter(
        torch::Tensor input,
        torch::Tensor pos);

std::vector<torch::Tensor> moe_cuda_local_gather(
        torch::Tensor output_buf,
        torch::Tensor pos);

std::vector<torch::Tensor> moe_cuda_forward(
        torch::Tensor input_buf,
        torch::Tensor weight,
        torch::Tensor expert_count);

std::vector<torch::Tensor> moe_cuda_backward(
        torch::Tensor grad_output_buf,
        torch::Tensor input_buf,
        torch::Tensor weight,
        torch::Tensor expert_count);

#ifdef MOE_USE_NCCL

std::vector<torch::Tensor> moe_cuda_global_scatter(
        torch::Tensor input_buf,
        torch::Tensor local_expert_count,
        torch::Tensor global_expert_count,
        long batch_size, long n_workers);

std::vector<torch::Tensor> moe_cuda_global_gather(
        torch::Tensor output_buf,
        torch::Tensor local_expert_count,
        torch::Tensor global_expert_count,
        long batch_size, long n_workers);

#endif

#endif  // MOE_CUDA_KERNEL_H
pytorch/cuda/moe_test.py

@@ -4,7 +4,7 @@ import time
 import sys
 
-dev_name = 'cuda:0'
+dev_name = 'cuda:1'
 
 def perf():

@@ -16,7 +16,7 @@ def perf():
     out_feat = int(sys.argv[3])
     num_expert = int(sys.argv[4])
-    inp = torch.rand(batch_size, io_feat).cuda(dev_name)
+    inp = torch.rand(batch_size, in_feat).cuda(dev_name)
     gate = torch.randint(low=0,
             high=num_expert * torch.distributed.get_world_size(),
             size=(batch_size,), requires_grad=False).int().cuda(dev_name)
pytorch/cuda/run.sh

@@ -26,5 +26,5 @@ then
     done
     done
 else
-    python3 $@ 2>logs/$OMPI_COMM_WORLD_RANK.log
+    python3 $@ # 2>logs/$OMPI_COMM_WORLD_RANK.log
 fi
pytorch/cuda/setup.py

@@ -12,15 +12,16 @@ setup(
             sources=[
                 'moe.cpp',
                 'cuda_stream_manager.cpp',
-                'comm_manager.cpp',
                 'moe_cuda_kernel.cu',
                 ],
             extra_compile_args={
                 'cxx': [
                     '-I{}'.format(CUDA_HELPER),
+                    '-DMOE_USE_NCCL'
                     ],
                 'nvcc': [
                     '-I{}'.format(CUDA_HELPER),
+                    '-DMOE_USE_NCCL'
                     ]
                 }
             )
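With -DMOE_USE_NCCL now passed for both the cxx and nvcc compile args, the global_scatter / global_gather bindings should always be compiled in. A quick sanity check from Python might look like the sketch below; the module name moe_cuda is an assumption, since the actual name is whatever setup.py registers for TORCH_EXTENSION_NAME and is not shown in this diff.

    # Sketch: confirm the NCCL-only bindings are present in the built extension.
    import moe_cuda  # assumed module name

    for fn in ("global_scatter", "global_gather"):
        print(fn, "available:", hasattr(moe_cuda, fn))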