OpenDAS / FastMoE · Commits

Commit 7c3e5149, authored Jan 17, 2021 by Rick Ho
Parent: 969ef607

    separate kernel files

Showing 4 changed files with 214 additions and 182 deletions:
    pytorch/cuda/moe_comm_kernel.cu      +189  −0
    pytorch/cuda/moe_compute_kernel.cu     +0  −181
    pytorch/cuda/moe_fused_kernel.cu      +22  −0
    pytorch/cuda/setup.py                  +3  −1
pytorch/cuda/moe_comm_kernel.cu (new file, mode 100644)
#include "moe_cuda_kernel.h"
#include <cstdio>
#include <iostream>
#include <vector>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <helper_cuda.h>
#include <c10/cuda/CUDAGuard.h>
#include "cuda_stream_manager.h"
#ifdef MOE_USE_NCCL
#include <mpi.h>
#include <nccl.h>
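
// Exchange per-expert token counts across all workers via MPI all-to-all,
// then sum over workers to get the number of tokens each local expert
// must process in the forward pass.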
void moe_cuda_expert_exchange_impl(
        const int* local_expert_count,
        int* global_expert_count,
        int* fwd_expert_count,
        int num_expert, int world_size) {
    MPI_Alltoall(local_expert_count, num_expert, MPI_INT,
            global_expert_count, num_expert, MPI_INT,
            MPI_COMM_WORLD);
    for (int i = 0; i < num_expert; ++i) {
        for (int j = 0; j < world_size; ++j) {
            fwd_expert_count[i] += global_expert_count[i + j * num_expert];
        }
    }
}
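
// Tensor-level wrapper around the exchange: allocates the result count
// tensors, then calls the raw implementation above.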
std::vector<torch::Tensor> moe_cuda_expert_exchange(
        torch::Tensor local_expert_count,
        long num_expert, long n_workers) {
    auto global_expert_count = torch::empty_like(local_expert_count);
    auto fwe_options = torch::TensorOptions()
        .dtype(local_expert_count.dtype());
    auto fwd_expert_count = torch::zeros({num_expert}, fwe_options);
    moe_cuda_expert_exchange_impl(
        local_expert_count.data_ptr<int>(),
        global_expert_count.data_ptr<int>(),
        fwd_expert_count.data_ptr<int>(),
        num_expert,
        n_workers);
    return {global_expert_count, fwd_expert_count};
}
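
// Global scatter: send each worker's tokens to the workers hosting the
// experts they were routed to (counts in local_expert_count), and receive
// the tokens other workers route here (counts in global_expert_count).
// All sends and receives for one expert are batched in a single NCCL group.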
template<typename scalar_t>
void moe_cuda_global_scatter_impl(
        const scalar_t* local_input_buf,
        const int* local_expert_count,
        const int* global_expert_count,
        scalar_t* input_buf,
        size_t in_feat, size_t num_expert, size_t world_size,
        CudaStreamManager* smgr) {
    // assert world_size > 1
    int recv_ptr = 0;
    /* TODO: may save for backward */
    int *expert_ptr = new int[num_expert * world_size];
    expert_ptr[0] = 0;
    for (int i = 1; i < num_expert * world_size; ++i) {
        expert_ptr[i] = expert_ptr[i - 1] + local_expert_count[i - 1];
    }
    for (int i = 0; i < num_expert; ++i) {
        NCCL_SAFE_CALL(ncclGroupStart());
        for (int j = 0; j < world_size; ++j) {
            int idx = i + j * num_expert;
            if (local_expert_count[idx]) {
                NCCL_SAFE_CALL(ncclSend(
                        local_input_buf + expert_ptr[idx] * in_feat,
                        local_expert_count[idx] * in_feat * sizeof(scalar_t),
                        ncclChar,
                        j,
                        smgr->ncclcomm,
                        smgr->stream(0)));
            }
            if (global_expert_count[idx]) {
                NCCL_SAFE_CALL(ncclRecv(
                        input_buf + recv_ptr * in_feat,
                        global_expert_count[idx] * in_feat * sizeof(scalar_t),
                        ncclChar,
                        j,
                        smgr->ncclcomm,
                        smgr->stream(0)));
                recv_ptr += global_expert_count[idx];
            }
        }
        NCCL_SAFE_CALL(ncclGroupEnd());
    }
    delete [] expert_ptr;
    smgr->sync(1);
}
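
// Tensor-level wrapper: sizes the receive buffer from the caller-supplied
// batch_size and dispatches on the input's floating-point dtype.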
std::vector<torch::Tensor> moe_cuda_global_scatter(
        torch::Tensor input_buf,
        torch::Tensor local_expert_count,
        torch::Tensor global_expert_count,
        long batch_size, long n_workers) {
    auto num_expert = local_expert_count.size(0) / n_workers;
    auto in_feat = input_buf.size(1);
    auto global_input_buf = input_buf.new_empty({batch_size, in_feat});
    auto smgr = getCudaStreamManager(input_buf.device().index());
    AT_DISPATCH_FLOATING_TYPES(input_buf.scalar_type(),
            "moe_cuda_global_scatter", ([&] {
        moe_cuda_global_scatter_impl<scalar_t>(
            input_buf.data_ptr<scalar_t>(),
            local_expert_count.data_ptr<int>(),
            global_expert_count.data_ptr<int>(),
            global_input_buf.data_ptr<scalar_t>(),
            in_feat, num_expert, n_workers,
            smgr);
    }));
    return {global_input_buf,};
}
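
// Global gather: the inverse of the scatter above. Expert outputs are sent
// back to the workers the corresponding tokens came from, with the send and
// receive roles of the two count arrays swapped.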
template<typename scalar_t>
void moe_cuda_global_gather_impl(
        const scalar_t* output_buf,
        const int* local_expert_count,
        const int* global_expert_count,
        scalar_t* local_output_buf,
        size_t out_feat, size_t num_expert, size_t world_size,
        CudaStreamManager* smgr) {
    int send_ptr = 0;
    /* TODO: may save for backward */
    int *expert_ptr = new int[num_expert * world_size];
    expert_ptr[0] = 0;
    for (int i = 1; i < num_expert * world_size; ++i) {
        expert_ptr[i] = expert_ptr[i - 1] + local_expert_count[i - 1];
    }
    for (int i = 0; i < num_expert; ++i) {
        NCCL_SAFE_CALL(ncclGroupStart());
        for (int j = 0; j < world_size; ++j) {
            int idx = i + j * num_expert;
            if (global_expert_count[idx]) {
                NCCL_SAFE_CALL(ncclSend(
                        output_buf + send_ptr * out_feat,
                        global_expert_count[idx] * out_feat * sizeof(scalar_t),
                        ncclChar,
                        j,
                        smgr->ncclcomm,
                        smgr->stream(0)));
                send_ptr += global_expert_count[idx];
            }
            if (local_expert_count[idx]) {
                NCCL_SAFE_CALL(ncclRecv(
                        local_output_buf + expert_ptr[idx] * out_feat,
                        local_expert_count[idx] * out_feat * sizeof(scalar_t),
                        ncclChar,
                        j,
                        smgr->ncclcomm,
                        smgr->stream(0)));
            }
        }
        NCCL_SAFE_CALL(ncclGroupEnd());
    }
    delete [] expert_ptr;
    smgr->sync(1);
}
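
// Tensor-level wrapper for the gather path, mirroring moe_cuda_global_scatter.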
std::vector<torch::Tensor> moe_cuda_global_gather(
        torch::Tensor output_buf,
        torch::Tensor local_expert_count,
        torch::Tensor global_expert_count,
        long batch_size, long n_workers) {
    auto num_expert = local_expert_count.size(0) / n_workers;
    auto out_feat = output_buf.size(1);
    auto local_output_buf = output_buf.new_empty({batch_size, out_feat});
    auto smgr = getCudaStreamManager(output_buf.device().index());
    AT_DISPATCH_FLOATING_TYPES(output_buf.scalar_type(),
            "moe_cuda_global_gather", ([&] {
        moe_cuda_global_gather_impl<scalar_t>(
            output_buf.data_ptr<scalar_t>(),
            local_expert_count.data_ptr<int>(),
            global_expert_count.data_ptr<int>(),
            local_output_buf.data_ptr<scalar_t>(),
            out_feat, num_expert, n_workers,
            smgr);
    }));
    return {local_output_buf,};
}
#endif
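
The count buffers above share one layout: entry i + j * num_expert holds the token count for local expert i and remote worker j, and a running prefix sum (expert_ptr) turns those counts into offsets into the contiguous, expert-sorted token buffer. The following standalone sketch, with hypothetical counts and no MPI, NCCL, or GPU required, walks through that bookkeeping for two experts and two workers, mirroring the loops in moe_cuda_global_scatter_impl.

#include <cstdio>
#include <vector>

int main() {
    const int num_expert = 2, world_size = 2;
    // Hypothetical example: local_expert_count[i + j * num_expert] is the
    // number of local tokens routed to expert i on worker j.
    std::vector<int> local_expert_count = {3, 1, 0, 2};

    // Prefix sums, as in moe_cuda_global_scatter_impl: the offset of each
    // (expert, worker) slice inside the contiguous local token buffer.
    std::vector<int> expert_ptr(num_expert * world_size, 0);
    for (int k = 1; k < num_expert * world_size; ++k) {
        expert_ptr[k] = expert_ptr[k - 1] + local_expert_count[k - 1];
    }

    // Same iteration order as the ncclSend loop: expert-major, then worker.
    for (int i = 0; i < num_expert; ++i) {
        for (int j = 0; j < world_size; ++j) {
            int idx = i + j * num_expert;
            std::printf("expert %d -> worker %d: %d rows from offset %d\n",
                        i, j, local_expert_count[idx], expert_ptr[idx]);
        }
    }
    return 0;
}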
pytorch/cuda/moe_cuda_kernel.cu → pytorch/cuda/moe_compute_kernel.cu (renamed)
@@ -10,14 +10,8 @@
 #include <helper_cuda.h>
 #include <c10/cuda/CUDAGuard.h>
-#ifdef MOE_USE_NCCL
-#include <mpi.h>
-#include <nccl.h>
-#endif
 #include "timer.hh"
 #include "cublas_wrapper.h"
 #include "cuda_stream_manager.h"
@@ -80,181 +74,6 @@ void moe_cuda_expert_count_impl(
     delete [] expert_ptr;
 }
-#ifdef MOE_USE_NCCL
[removed lines elided: moe_cuda_expert_exchange_impl, moe_cuda_expert_exchange,
moe_cuda_global_scatter_impl, moe_cuda_global_scatter, moe_cuda_global_gather_impl,
and moe_cuda_global_gather, identical to the new moe_comm_kernel.cu shown above]
-#endif  // MOE_USE_NCCL
 template<typename scalar_t>
 void moe_cuda_local_scatter_impl(
         const scalar_t* input,
...
pytorch/cuda/moe_fused_kernel.cu (new file, mode 100644)
#include "moe_cuda_kernel.h"
#include <cstdio>
#include <iostream>
#include <vector>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <helper_cuda.h>
#include <c10/cuda/CUDAGuard.h>
#include "cuda_stream_manager.h"
#ifdef MOE_USE_NCCL
#include <mpi.h>
#include <nccl.h>
// TODO
#endif
pytorch/cuda/setup.py

@@ -17,7 +17,9 @@ setup(
         sources=[
             'moe.cpp',
             'cuda_stream_manager.cpp',
-            'moe_cuda_kernel.cu',
+            'moe_compute_kernel.cu',
+            'moe_comm_kernel.cu',
+            'moe_fused_kernel.cu',
         ],
         extra_compile_args={
             'cxx': cxx_flags,
...