Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
4f9f77f8
Commit
4f9f77f8
authored
Sep 11, 2023
by
Rick Ho
Browse files
use torchstream everywhere
parent
2bd187cb
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
30 additions
and
27 deletions
+30
-27
cuda/balancing.cuh
cuda/balancing.cuh
+2
-4
cuda/fastermoe/smart_schedule.h
cuda/fastermoe/smart_schedule.h
+8
-10
cuda/global_exchange.cpp
cuda/global_exchange.cpp
+2
-3
cuda/global_exchange.h
cuda/global_exchange.h
+4
-6
cuda/local_exchange.cuh
cuda/local_exchange.cuh
+2
-4
cuda/parallel_linear.cuh
cuda/parallel_linear.cuh
+2
-0
cuda/stream_manager.cpp
cuda/stream_manager.cpp
+8
-0
cuda/stream_manager.h
cuda/stream_manager.h
+2
-0
No files found.
cuda/balancing.cuh
View file @
4f9f77f8
...
@@ -25,9 +25,8 @@ void fmoe_cuda_limit_by_capacity_impl(const long* ec, int* cap,
...
@@ -25,9 +25,8 @@ void fmoe_cuda_limit_by_capacity_impl(const long* ec, int* cap,
CudaStreamManager
*
smgr
)
{
CudaStreamManager
*
smgr
)
{
dim3
grid_dim
(
CEIL
(
n_worker
,
1024
),
n_expert
);
dim3
grid_dim
(
CEIL
(
n_worker
,
1024
),
n_expert
);
dim3
block_dim
(
1024
);
dim3
block_dim
(
1024
);
limit_by_capacity_kernel
<<<
grid_dim
,
block_dim
,
0
,
smgr
->
s
tream
(
0
)
>>>
(
limit_by_capacity_kernel
<<<
grid_dim
,
block_dim
,
0
,
smgr
->
torchS
tream
()
>>>
(
ec
,
cap
,
eca
,
n_expert
,
n_worker
);
ec
,
cap
,
eca
,
n_expert
,
n_worker
);
smgr
->
sync
(
1
);
}
}
__global__
__global__
...
@@ -51,8 +50,7 @@ void fmoe_cuda_prune_gate_by_capacity_impl(long* gate_idx, long* new_gate_idx,
...
@@ -51,8 +50,7 @@ void fmoe_cuda_prune_gate_by_capacity_impl(long* gate_idx, long* new_gate_idx,
CudaStreamManager
*
smgr
)
{
CudaStreamManager
*
smgr
)
{
dim3
grid_dim
(
CEIL
(
batch_size
,
1024
));
dim3
grid_dim
(
CEIL
(
batch_size
,
1024
));
dim3
block_dim
(
1024
);
dim3
block_dim
(
1024
);
prune_gate_by_capacity_kernel
<<<
grid_dim
,
block_dim
,
0
,
smgr
->
s
tream
(
0
)
>>>
(
prune_gate_by_capacity_kernel
<<<
grid_dim
,
block_dim
,
0
,
smgr
->
torchS
tream
()
>>>
(
gate_idx
,
new_gate_idx
,
ec
,
batch_size
,
n_expert
,
n_worker
gate_idx
,
new_gate_idx
,
ec
,
batch_size
,
n_expert
,
n_worker
);
);
smgr
->
sync
(
1
);
}
}
cuda/fastermoe/smart_schedule.h
View file @
4f9f77f8
...
@@ -122,8 +122,7 @@ void fmoe_cuda_fused_forward_impl(
...
@@ -122,8 +122,7 @@ void fmoe_cuda_fused_forward_impl(
long
d_model
,
long
d_model
,
long
num_expert
,
long
rank
,
long
world_size
,
long
expert_size
,
long
num_expert
,
long
rank
,
long
world_size
,
long
expert_size
,
long
pipeline_gran
,
CudaStreamManager
*
smgr
)
{
long
pipeline_gran
,
CudaStreamManager
*
smgr
)
{
auto
torch_stream
=
c10
::
cuda
::
getCurrentCUDAStream
().
stream
();
smgr
->
syncTorch
();
cudaStreamSynchronize
(
torch_stream
);
int
*
local_ptr
=
new
int
[
num_expert
*
world_size
+
1
];
int
*
local_ptr
=
new
int
[
num_expert
*
world_size
+
1
];
int
*
global_ptr
=
new
int
[
num_expert
*
world_size
+
1
];
int
*
global_ptr
=
new
int
[
num_expert
*
world_size
+
1
];
...
@@ -192,7 +191,7 @@ void fmoe_cuda_fused_forward_impl(
...
@@ -192,7 +191,7 @@ void fmoe_cuda_fused_forward_impl(
// C_0 ... C_n
// C_0 ... C_n
for
(
long
step
=
0
;
step
<
n_groups
;
++
step
)
{
for
(
long
step
=
0
;
step
<
n_groups
;
++
step
)
{
FMOE_SWE
(
smgr
->
stream
(
0
),
input_ready
[
step
]);
FMOE_SWE
(
smgr
->
stream
(
0
),
input_ready
[
step
]);
FMOE_SWE
(
torch
_s
tream
,
input_ready
[
step
]);
FMOE_SWE
(
smgr
->
torch
S
tream
()
,
input_ready
[
step
]);
for
(
int
ei
=
0
;
ei
<
num_expert
;
++
ei
)
{
for
(
int
ei
=
0
;
ei
<
num_expert
;
++
ei
)
{
GEN_BASE
(
step
);
GEN_BASE
(
step
);
long
offset
=
global_ptr
[
ei
*
world_size
+
from_base
];
long
offset
=
global_ptr
[
ei
*
world_size
+
from_base
];
...
@@ -203,14 +202,14 @@ void fmoe_cuda_fused_forward_impl(
...
@@ -203,14 +202,14 @@ void fmoe_cuda_fused_forward_impl(
(
long
)
ei
,
step
*
num_expert
+
ei
,
offset
,
micro_batch_size
,
d_model
,
smgr
);
(
long
)
ei
,
step
*
num_expert
+
ei
,
offset
,
micro_batch_size
,
d_model
,
smgr
);
}
}
cudaEventRecord
(
output_ready
[
step
],
smgr
->
stream
(
0
));
cudaEventRecord
(
output_ready
[
step
],
smgr
->
stream
(
0
));
cudaEventRecord
(
output_torch_ready
[
step
],
torch
_s
tream
);
cudaEventRecord
(
output_torch_ready
[
step
],
smgr
->
torch
S
tream
()
);
}
}
// Compute over shadowed experts
// Compute over shadowed experts
for
(
long
i
=
0
,
si
=
0
;
i
<
world_size
*
num_expert
;
++
i
)
{
for
(
long
i
=
0
,
si
=
0
;
i
<
world_size
*
num_expert
;
++
i
)
{
if
(
stored_models
[
i
])
{
if
(
stored_models
[
i
])
{
FMOE_SWE
(
smgr
->
stream
(
0
),
evt_shadow
[
si
]);
FMOE_SWE
(
smgr
->
stream
(
0
),
evt_shadow
[
si
]);
FMOE_SWE
(
torch
_s
tream
,
evt_shadow
[
si
]);
FMOE_SWE
(
smgr
->
torch
S
tream
()
,
evt_shadow
[
si
]);
stash_fn
(
params
[
si
],
si
,
0
);
// always put shadowed expert at first, so expert_idx = 0
stash_fn
(
params
[
si
],
si
,
0
);
// always put shadowed expert at first, so expert_idx = 0
long
offset
=
local_ptr
[
i
];
long
offset
=
local_ptr
[
i
];
long
micro_batch_size
=
local_expert_count
[
i
];
long
micro_batch_size
=
local_expert_count
[
i
];
...
@@ -282,8 +281,7 @@ void fmoe_cuda_fused_backward_impl(
...
@@ -282,8 +281,7 @@ void fmoe_cuda_fused_backward_impl(
long
d_model
,
long
d_model
,
long
num_expert
,
long
rank
,
long
world_size
,
long
num_expert
,
long
rank
,
long
world_size
,
long
pipeline_gran
,
CudaStreamManager
*
smgr
)
{
long
pipeline_gran
,
CudaStreamManager
*
smgr
)
{
auto
torch_stream
=
c10
::
cuda
::
getCurrentCUDAStream
().
stream
();
smgr
->
syncTorch
();
cudaStreamSynchronize
(
torch_stream
);
int
*
local_ptr
=
new
int
[
num_expert
*
world_size
+
1
];
int
*
local_ptr
=
new
int
[
num_expert
*
world_size
+
1
];
int
*
global_ptr
=
new
int
[
num_expert
*
world_size
+
1
];
int
*
global_ptr
=
new
int
[
num_expert
*
world_size
+
1
];
...
@@ -350,7 +348,7 @@ void fmoe_cuda_fused_backward_impl(
...
@@ -350,7 +348,7 @@ void fmoe_cuda_fused_backward_impl(
// C_0 ... C_n
// C_0 ... C_n
for
(
long
step
=
0
;
step
<
n_groups
;
++
step
)
{
for
(
long
step
=
0
;
step
<
n_groups
;
++
step
)
{
FMOE_SWE
(
smgr
->
stream
(
0
),
input_ready
[
step
]);
FMOE_SWE
(
smgr
->
stream
(
0
),
input_ready
[
step
]);
FMOE_SWE
(
torch
_s
tream
,
input_ready
[
step
]);
FMOE_SWE
(
smgr
->
torch
S
tream
()
,
input_ready
[
step
]);
for
(
int
ei
=
0
;
ei
<
num_expert
;
++
ei
)
{
for
(
int
ei
=
0
;
ei
<
num_expert
;
++
ei
)
{
GEN_BASE
(
step
);
GEN_BASE
(
step
);
long
offset
=
global_ptr
[
ei
*
world_size
+
from_base
];
long
offset
=
global_ptr
[
ei
*
world_size
+
from_base
];
...
@@ -362,7 +360,7 @@ void fmoe_cuda_fused_backward_impl(
...
@@ -362,7 +360,7 @@ void fmoe_cuda_fused_backward_impl(
(
long
)
ei
,
step
*
num_expert
+
ei
,
offset
,
micro_batch_size
,
d_model
,
smgr
);
(
long
)
ei
,
step
*
num_expert
+
ei
,
offset
,
micro_batch_size
,
d_model
,
smgr
);
}
}
cudaEventRecord
(
output_ready
[
step
],
smgr
->
stream
(
0
));
cudaEventRecord
(
output_ready
[
step
],
smgr
->
stream
(
0
));
cudaEventRecord
(
output_torch_ready
[
step
],
torch
_s
tream
);
cudaEventRecord
(
output_torch_ready
[
step
],
smgr
->
torch
S
tream
()
);
}
}
// Collect gradients for shadowed experts
// Collect gradients for shadowed experts
...
@@ -370,7 +368,7 @@ void fmoe_cuda_fused_backward_impl(
...
@@ -370,7 +368,7 @@ void fmoe_cuda_fused_backward_impl(
if
(
stored_models
[
i
])
{
if
(
stored_models
[
i
])
{
if
(
i
/
num_expert
==
rank
)
{
if
(
i
/
num_expert
==
rank
)
{
FMOE_SWE
(
smgr
->
stream
(
0
),
evt_reduce
[
i
%
num_expert
]);
FMOE_SWE
(
smgr
->
stream
(
0
),
evt_reduce
[
i
%
num_expert
]);
FMOE_SWE
(
torch
_s
tream
,
evt_reduce
[
i
%
num_expert
]);
FMOE_SWE
(
smgr
->
torch
S
tream
()
,
evt_reduce
[
i
%
num_expert
]);
set_grad_fn
(
si
,
i
%
num_expert
);
set_grad_fn
(
si
,
i
%
num_expert
);
}
}
++
si
;
++
si
;
...
...
cuda/global_exchange.cpp
View file @
4f9f77f8
...
@@ -19,17 +19,16 @@ void fmoe_cuda_expert_exchange_impl(
...
@@ -19,17 +19,16 @@ void fmoe_cuda_expert_exchange_impl(
ncclInt64
,
ncclInt64
,
i
,
i
,
smgr
->
ncclcomm
,
smgr
->
ncclcomm
,
smgr
->
s
tream
(
0
)));
smgr
->
torchS
tream
()));
NCCL_SAFE_CALL
(
ncclRecv
(
NCCL_SAFE_CALL
(
ncclRecv
(
global_expert_count
+
n_expert
*
i
,
global_expert_count
+
n_expert
*
i
,
n_expert
,
n_expert
,
ncclInt64
,
ncclInt64
,
i
,
i
,
smgr
->
ncclcomm
,
smgr
->
ncclcomm
,
smgr
->
s
tream
(
0
)));
smgr
->
torchS
tream
()));
}
}
NCCL_SAFE_CALL
(
ncclGroupEnd
());
NCCL_SAFE_CALL
(
ncclGroupEnd
());
smgr
->
sync
(
1
);
}
}
torch
::
Tensor
_expert_exchange
(
torch
::
Tensor
_expert_exchange
(
...
...
cuda/global_exchange.h
View file @
4f9f77f8
...
@@ -36,7 +36,7 @@ void fmoe_cuda_global_scatter_impl(
...
@@ -36,7 +36,7 @@ void fmoe_cuda_global_scatter_impl(
ncclChar
,
ncclChar
,
j
,
j
,
smgr
->
ncclcomm
,
smgr
->
ncclcomm
,
smgr
->
s
tream
(
0
)));
smgr
->
torchS
tream
()));
}
}
if
(
global_expert_count
[
idx
])
{
if
(
global_expert_count
[
idx
])
{
NCCL_SAFE_CALL
(
ncclRecv
(
NCCL_SAFE_CALL
(
ncclRecv
(
...
@@ -45,14 +45,13 @@ void fmoe_cuda_global_scatter_impl(
...
@@ -45,14 +45,13 @@ void fmoe_cuda_global_scatter_impl(
ncclChar
,
ncclChar
,
j
,
j
,
smgr
->
ncclcomm
,
smgr
->
ncclcomm
,
smgr
->
s
tream
(
0
)));
smgr
->
torchS
tream
()));
recv_ptr
+=
global_expert_count
[
idx
];
recv_ptr
+=
global_expert_count
[
idx
];
}
}
}
}
NCCL_SAFE_CALL
(
ncclGroupEnd
());
NCCL_SAFE_CALL
(
ncclGroupEnd
());
}
}
delete
[]
expert_ptr
;
delete
[]
expert_ptr
;
smgr
->
sync
(
1
);
}
}
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
...
@@ -82,7 +81,7 @@ void fmoe_cuda_global_gather_impl(
...
@@ -82,7 +81,7 @@ void fmoe_cuda_global_gather_impl(
ncclChar
,
ncclChar
,
j
,
j
,
smgr
->
ncclcomm
,
smgr
->
ncclcomm
,
smgr
->
s
tream
(
0
)));
smgr
->
torchS
tream
()));
send_ptr
+=
global_expert_count
[
idx
];
send_ptr
+=
global_expert_count
[
idx
];
}
}
if
(
local_expert_count
[
idx
])
{
if
(
local_expert_count
[
idx
])
{
...
@@ -92,13 +91,12 @@ void fmoe_cuda_global_gather_impl(
...
@@ -92,13 +91,12 @@ void fmoe_cuda_global_gather_impl(
ncclChar
,
ncclChar
,
j
,
j
,
smgr
->
ncclcomm
,
smgr
->
ncclcomm
,
smgr
->
s
tream
(
0
)));
smgr
->
torchS
tream
()));
}
}
}
}
NCCL_SAFE_CALL
(
ncclGroupEnd
());
NCCL_SAFE_CALL
(
ncclGroupEnd
());
}
}
delete
[]
expert_ptr
;
delete
[]
expert_ptr
;
smgr
->
sync
(
1
);
}
}
...
...
cuda/local_exchange.cuh
View file @
4f9f77f8
...
@@ -21,9 +21,8 @@ void fmoe_cuda_assign_pos_impl(
...
@@ -21,9 +21,8 @@ void fmoe_cuda_assign_pos_impl(
CudaStreamManager
*
smgr
)
{
CudaStreamManager
*
smgr
)
{
size_t
numel
=
batch_size
*
topk
;
size_t
numel
=
batch_size
*
topk
;
assign_pos_kernel
assign_pos_kernel
<<<
CEIL
(
numel
,
256
),
256
,
0
,
smgr
->
s
tream
(
0
)
>>>
<<<
CEIL
(
numel
,
256
),
256
,
0
,
smgr
->
torchS
tream
()
>>>
(
cum_count
,
gate
,
pos
,
numel
,
topk
);
(
cum_count
,
gate
,
pos
,
numel
,
topk
);
smgr
->
sync
(
1
);
}
}
#define PERTHREAD_EXPERTS 256
#define PERTHREAD_EXPERTS 256
...
@@ -74,7 +73,6 @@ void fmoe_cuda_expert_count_impl(
...
@@ -74,7 +73,6 @@ void fmoe_cuda_expert_count_impl(
const
size_t
batch_size
,
const
size_t
n_expert
,
const
size_t
batch_size
,
const
size_t
n_expert
,
CudaStreamManager
*
smgr
)
{
CudaStreamManager
*
smgr
)
{
expert_count_kernel
expert_count_kernel
<<<
CEIL
(
n_expert
,
PERTHREAD_EXPERTS
),
256
,
0
,
smgr
->
s
tream
(
0
)
>>>
<<<
CEIL
(
n_expert
,
PERTHREAD_EXPERTS
),
256
,
0
,
smgr
->
torchS
tream
()
>>>
(
gate_idx
,
expert_count
,
batch_size
,
n_expert
);
(
gate_idx
,
expert_count
,
batch_size
,
n_expert
);
smgr
->
sync
(
1
);
}
}
cuda/parallel_linear.cuh
View file @
4f9f77f8
...
@@ -65,6 +65,7 @@ void fmoe_cuda_linear_forward_impl(
...
@@ -65,6 +65,7 @@ void fmoe_cuda_linear_forward_impl(
CudaStreamManager
*
smgr
)
{
CudaStreamManager
*
smgr
)
{
scalar_t
alpha
=
1
,
beta
=
has_bias
?
1
:
0
;
scalar_t
alpha
=
1
,
beta
=
has_bias
?
1
:
0
;
smgr
->
syncTorch
();
for
(
int
i
=
0
,
ptr
=
0
;
i
<
num_expert
;
++
i
)
{
for
(
int
i
=
0
,
ptr
=
0
;
i
<
num_expert
;
++
i
)
{
if
(
expert_count
[
i
]
==
0
)
{
if
(
expert_count
[
i
]
==
0
)
{
continue
;
continue
;
...
@@ -102,6 +103,7 @@ void fmoe_cuda_linear_backward_impl(
...
@@ -102,6 +103,7 @@ void fmoe_cuda_linear_backward_impl(
const
size_t
out_feat
,
const
size_t
out_feat
,
const
size_t
num_expert
,
const
size_t
num_expert
,
CudaStreamManager
*
smgr
)
{
CudaStreamManager
*
smgr
)
{
smgr
->
syncTorch
();
scalar_t
alpha
=
1
,
beta
=
0
;
scalar_t
alpha
=
1
,
beta
=
0
;
// bias
// bias
...
...
cuda/stream_manager.cpp
View file @
4f9f77f8
...
@@ -19,6 +19,10 @@ cudaStream_t CudaStreamManager::stream(size_t idx) {
...
@@ -19,6 +19,10 @@ cudaStream_t CudaStreamManager::stream(size_t idx) {
return
this
->
streams
[
idx
%
SMGR_N_STREAMS
];
return
this
->
streams
[
idx
%
SMGR_N_STREAMS
];
}
}
cudaStream_t
CudaStreamManager
::
torchStream
()
{
return
c10
::
cuda
::
getCurrentCUDAStream
().
stream
();
}
cublasHandle_t
CudaStreamManager
::
handle
(
size_t
idx
)
{
cublasHandle_t
CudaStreamManager
::
handle
(
size_t
idx
)
{
if
(
this
->
use_default
)
{
if
(
this
->
use_default
)
{
return
at
::
cuda
::
getCurrentCUDABlasHandle
();
return
at
::
cuda
::
getCurrentCUDABlasHandle
();
...
@@ -27,6 +31,10 @@ cublasHandle_t CudaStreamManager::handle(size_t idx) {
...
@@ -27,6 +31,10 @@ cublasHandle_t CudaStreamManager::handle(size_t idx) {
}
}
void
CudaStreamManager
::
syncTorch
()
{
cudaStreamSynchronize
(
this
->
torchStream
());
}
void
CudaStreamManager
::
sync
(
int
idx
)
{
void
CudaStreamManager
::
sync
(
int
idx
)
{
if
(
this
->
use_default
)
{
if
(
this
->
use_default
)
{
return
;
return
;
...
...
cuda/stream_manager.h
View file @
4f9f77f8
...
@@ -34,8 +34,10 @@ public:
...
@@ -34,8 +34,10 @@ public:
void
setup
(
int
);
void
setup
(
int
);
void
sync
(
int
=
0
);
void
sync
(
int
=
0
);
void
syncTorch
();
void
destroy
();
void
destroy
();
cudaStream_t
torchStream
();
cudaStream_t
stream
(
size_t
=
0
);
cudaStream_t
stream
(
size_t
=
0
);
cublasHandle_t
handle
(
size_t
=
0
);
cublasHandle_t
handle
(
size_t
=
0
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment