OpenDAS / FastMoE / Commits

Commit 143e21cc
Authored Jan 09, 2021 by Rick Ho

    update stream manager

Parent: 2565f2fa
Changes: 4 changed files with 8 additions and 57 deletions (+8 -57)

    pytorch/cuda/cuda_stream_manager.cpp   +1  -1
    pytorch/cuda/moe_cuda_kernel.cu        +3  -52
    pytorch/cuda/moe_test.py               +3  -3
    pytorch/cuda/setup.py                  +1  -1
pytorch/cuda/cuda_stream_manager.cpp

@@ -9,7 +9,7 @@ void CudaStreamManager::sync(int i) {
         cudaStreamSynchronize(streams[i]);
         return;
     }
-    for (size_t i = 0; i < MAX_STREAMS; ++i) {
+    for (size_t i = 0; i < this->num_expert; ++i) {
         cudaStreamSynchronize(streams[i]);
     }
 }
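The kernel code below references smgr.streams[i], smgr.handles[i], and smgr.sync(i). For orientation, here is a minimal sketch of what such a stream manager could look like, assuming one CUDA stream and one cuBLAS handle per expert; the constructor, the guard in sync(), and the MAX_STREAMS capacity are assumptions inferred from this diff, not the file's actual contents.

    #include <cuda_runtime.h>
    #include <cublas_v2.h>

    #define MAX_STREAMS 16  // assumed capacity; the removed loop iterated over it

    class CudaStreamManager {
    public:
        size_t num_expert;
        cudaStream_t streams[MAX_STREAMS];
        cublasHandle_t handles[MAX_STREAMS];

        // One stream plus one cuBLAS handle per expert, so per-expert
        // GEMMs can be issued concurrently.
        explicit CudaStreamManager(size_t num_expert) : num_expert(num_expert) {
            for (size_t i = 0; i < num_expert; ++i) {
                cudaStreamCreate(&streams[i]);
                cublasCreate(&handles[i]);
                cublasSetStream(handles[i], streams[i]);
            }
        }

        // The guard is an assumption: sync(i) with i > 0 waits on one
        // stream, while the fallthrough waits on all of them. After this
        // commit the fallthrough covers only the num_expert streams in
        // use rather than all MAX_STREAMS slots.
        void sync(int i) {
            if (i > 0) {
                cudaStreamSynchronize(streams[i]);
                return;
            }
            for (size_t i = 0; i < this->num_expert; ++i) {
                cudaStreamSynchronize(streams[i]);
            }
        }
    };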
pytorch/cuda/moe_cuda_kernel.cu

@@ -69,10 +69,6 @@ void moe_cuda_forward_impl(
         const size_t num_expert,
         cublasOperation_t transb) {
-#ifdef MOE_BREAKDOWN
-    timestamp(t_init);
-#endif
     scalar_t *input_buf, *output_buf;
     checkCudaErrors(cudaMalloc(&input_buf, sizeof(scalar_t) * batch_size *
@@ -80,12 +76,6 @@ void moe_cuda_forward_impl(
     checkCudaErrors(cudaMalloc(&output_buf, sizeof(scalar_t) * batch_size *
                 out_feat));
-#ifdef MOE_BREAKDOWN
-    timestamp(t_malloc);
-    fprintf(stderr, "Malloc time %.3lf us\n", getDuration(t_init, t_malloc) * 1e6);
-#endif
     int *gate = new int[batch_size];
     int *expert_count = new int[num_expert], *expert_ptr = new int[num_expert];
     memset(expert_count, 0, sizeof(int) * num_expert);
@@ -93,12 +83,6 @@ void moe_cuda_forward_impl(
     checkCudaErrors(cudaMemcpy(gate, d_gate, sizeof(int) * batch_size,
                 cudaMemcpyDeviceToHost));
-#ifdef MOE_BREAKDOWN
-    timestamp(t_cpy);
-    fprintf(stderr, "Copy time %.3lf us\n", getDuration(t_malloc, t_cpy) * 1e6);
-#endif
     for (int i = 0; i < batch_size; ++i) {
         ++expert_count[gate[i]];
     }
@@ -117,23 +101,10 @@ void moe_cuda_forward_impl(
     checkCudaErrors(cudaMemcpy(d_pos, pos, sizeof(int) * batch_size,
                 cudaMemcpyHostToDevice));
-#ifdef MOE_BREAKDOWN
-    timestamp(t_expert);
-    fprintf(stderr, "Expert asn time %.3lf us\n", getDuration(t_cpy, t_expert) * 1e6);
-#endif
     batch_scatter_kernel<scalar_t>
             <<<batch_size, 256, 0, smgr.streams[0]>>>(in_feat, d_pos, input,
             input_buf);
-    // smgr.sync(0);
+    smgr.sync(0);
-#ifdef MOE_BREAKDOWN
-    // h->sync();
-    timestamp(t_scatter);
-    fprintf(stderr, "Scatter time %.3lf us\n", getDuration(t_expert, t_scatter) * 1e6);
-#endif
     scalar_t alpha = 1, beta = 0;
@@ -141,13 +112,8 @@ void moe_cuda_forward_impl(
         if (expert_count[i] == 0) {
             continue;
         }
-#ifdef MOE_DEBUG_SCATTER
-        fprintf(stderr, "gemm %d sz %d\n", i, expert_count[i]);
-        fprintf(stderr, "GeMM %d x %d x %d\n", out_feat, expert_count[i], in_feat);
-#endif
         // Use T(B) x T(A) = T(C) to produce row-major C
-        checkCudaErrors(cublasXgemm(smgr.handles[0], // h->getHandle(i),
+        checkCudaErrors(cublasXgemm(smgr.handles[i],
                 CUBLAS_OP_T,
                 CUBLAS_OP_N,
                 out_feat, expert_count[i], in_feat,
@@ -161,25 +127,10 @@ void moe_cuda_forward_impl(
         ptr += expert_count[i];
     }
-#ifdef MOE_BREAKDOWN
-    timestamp(t_mm);
-    fprintf(stderr, "GeMM time %.3lf us\n", getDuration(t_scatter, t_mm) * 1e6);
-#endif
-    // h->sync();
     batch_gather_kernel<scalar_t>
             <<<batch_size, 256, 0, smgr.streams[0]>>>(out_feat, d_pos, output_buf,
             output);
-    // h->sync(0);
+    smgr.sync(0);
-#ifdef MOE_BREAKDOWN
-    timestamp(t_gather);
-    fprintf(stderr, "Gather time %.3lf us\n", getDuration(t_mm, t_gather) * 1e6);
-    fprintf(stderr, "Overall time %.3lf us\n", getDuration(t_init, t_gather) * 1e6);
-#endif
     cudaFree(input_buf);
     cudaFree(output_buf);
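batch_scatter_kernel and batch_gather_kernel are launched above with one thread block per batch row and 256 threads per block, all on smgr.streams[0]. A plausible sketch of the scatter side, inferred from the call site alone; the real kernel body is not part of this diff, and whether d_pos holds source or destination indices is an assumption here.

    // Copies each row of `wid` features from inbuf (original batch order)
    // into oubuf (expert-sorted order). Block b handles row b; its 256
    // threads stride across the row's features.
    template <typename scalar_t>
    __global__ void batch_scatter_kernel(size_t wid, const int* pos,
            const scalar_t* inbuf, scalar_t* oubuf) {
        for (size_t i = threadIdx.x; i < wid; i += blockDim.x) {
            // Assumption: pos[b] is the destination row for source row b.
            oubuf[(size_t)pos[blockIdx.x] * wid + i] =
                    inbuf[(size_t)blockIdx.x * wid + i];
        }
    }

batch_gather_kernel would then undo the permutation after the per-expert GEMMs, copying rows of out_feat features from the expert-sorted output buffer back into the original batch order.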
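The kept comment "Use T(B) x T(A) = T(C) to produce row-major C" is the standard trick for driving column-major cuBLAS from row-major buffers: since (A·B)^T = B^T·A^T, asking cuBLAS for C^T leaves C laid out row-major in memory. With X the (expert_count[i] × in_feat) row-major slab of this expert's input rows and W its (out_feat × in_feat) row-major weight, the call computes C^T = W·X^T, which is why the operand flags are CUBLAS_OP_T, CUBLAS_OP_N and the dimensions read (out_feat, expert_count[i], in_feat). A hedged sketch of the full call; cublasXgemm is presumably the project's type-dispatching wrapper over cublasSgemm/cublasDgemm, and the pointer offsets and leading dimensions below are assumptions, not lines from this diff.

    // Assumed layout: expert i's rows start at row offset `ptr` of
    // input_buf/output_buf, and expert weights are stored contiguously.
    checkCudaErrors(cublasXgemm(smgr.handles[i],  // per-expert handle after this commit
            CUBLAS_OP_T, CUBLAS_OP_N,
            out_feat, expert_count[i], in_feat,
            &alpha,
            weight + (size_t)i * in_feat * out_feat, in_feat,
            input_buf + (size_t)ptr * in_feat, in_feat,
            &beta,
            output_buf + (size_t)ptr * out_feat, out_feat));

Moving from smgr.handles[0] to smgr.handles[i] issues each expert's GEMM on its own handle, and presumably its own stream, which is what makes the broadened sync loop in cuda_stream_manager.cpp above necessary.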
pytorch/cuda/moe_test.py

@@ -6,8 +6,8 @@ import sys
 def perf():
     batch_size = int(sys.argv[1])
-    io_feat = int(sys.argv[2])
-    hidden_feat = int(sys.argv[3])
+    in_feat = int(sys.argv[2])
+    out_feat = int(sys.argv[3])
     num_expert = int(sys.argv[4])
@@ -36,7 +36,7 @@ def perf():
     sqtot += (te - ts) ** 2
     maxt = max(maxt, te - ts)
-    gflops = 2e-9 * n_runs * io_feat * hidden_feat * 2 * batch_size / tott
+    gflops = 2e-9 * n_runs * in_feat * out_feat * batch_size / tott
     print('Time mean/max/stdev {:.3f} {:.3f} {:.3f} ms, {:.3f} GFLOPs'.format(
         tott * 1e3 / n_runs, maxt * 1e3,
         (sqtot / n_runs - (tott / n_runs) ** 2) * 1e3 / n_runs, gflops))
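The corrected GFLOPs expression follows from the cost of one forward GEMM: multiplying a (batch_size × in_feat) input by an (in_feat × out_feat) weight takes one multiply and one add per term,

    FLOPs per run = 2 · batch_size · in_feat · out_feat

so over n_runs runs totaling tott seconds,

    GFLOPs = 2e-9 · n_runs · batch_size · in_feat · out_feat / tott.

The dropped factor of 2 presumably belonged to the earlier two-GEMM benchmark shape (io_feat -> hidden_feat -> io_feat); with the arguments renamed to a single in_feat -> out_feat layer it would double-count.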
pytorch/cuda/setup.py

@@ -11,7 +11,7 @@ setup(
             name='moe_cuda',
             sources=[
                 'moe.cpp',
-                # 'cuda_stream_manager.cpp',
+                'cuda_stream_manager.cpp',
                 'moe_cuda_kernel.cu',
             ],
             extra_compile_args={'cxx': ['-I{}'.format(CUDA_HELPER)],
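With 'cuda_stream_manager.cpp' previously commented out, the extension was building without compiling the stream manager translation unit; re-enabling it here matches the kernel's switch to smgr for its streams and cuBLAS handles.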