Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
b8a212ef
Commit
b8a212ef
authored
Jan 09, 2021
by
Rick Ho
Browse files
fix bugs to pass test
parent
bf4388c0
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
14 additions
and
13 deletions
+14
-13
pytorch/cuda/moe_cuda_kernel.cu
pytorch/cuda/moe_cuda_kernel.cu
+14
-13
No files found.
pytorch/cuda/moe_cuda_kernel.cu
View file @
b8a212ef
...
@@ -66,9 +66,8 @@ void moe_cuda_forward_impl(
...
@@ -66,9 +66,8 @@ void moe_cuda_forward_impl(
const
size_t
batch_size
,
const
size_t
batch_size
,
const
size_t
in_feat
,
const
size_t
in_feat
,
const
size_t
out_feat
,
const
size_t
out_feat
,
const
size_t
num_expert
)
{
const
size_t
num_expert
,
cublasOperation_t
transb
)
{
auto
h
=
getCudaStreamManager
(
num_expert
);
#ifdef MOE_BREAKDOWN
#ifdef MOE_BREAKDOWN
timestamp
(
t_init
);
timestamp
(
t_init
);
...
@@ -106,6 +105,7 @@ void moe_cuda_forward_impl(
...
@@ -106,6 +105,7 @@ void moe_cuda_forward_impl(
expert_ptr
[
0
]
=
0
;
expert_ptr
[
0
]
=
0
;
for
(
int
i
=
1
;
i
<
num_expert
;
++
i
)
{
for
(
int
i
=
1
;
i
<
num_expert
;
++
i
)
{
expert_ptr
[
i
]
=
expert_ptr
[
i
-
1
]
+
expert_count
[
i
-
1
];
expert_ptr
[
i
]
=
expert_ptr
[
i
-
1
]
+
expert_count
[
i
-
1
];
}
int
*
pos
=
new
int
[
batch_size
];
int
*
pos
=
new
int
[
batch_size
];
int
*
d_pos
;
int
*
d_pos
;
...
@@ -124,12 +124,12 @@ void moe_cuda_forward_impl(
...
@@ -124,12 +124,12 @@ void moe_cuda_forward_impl(
#endif
#endif
batch_scatter_kernel
<
scalar_t
>
batch_scatter_kernel
<
scalar_t
>
<<<
batch_size
,
256
,
0
,
h
->
getStream
(
0
)
>>>
(
in_feat
,
d_pos
,
input
,
<<<
batch_size
,
256
,
0
,
smgr
.
streams
[
0
]
>>>
(
in_feat
,
d_pos
,
input
,
input_buf
);
input_buf
);
h
->
sync
(
0
);
// smgr.sync(0);
#ifdef MOE_BREAKDOWN
#ifdef MOE_BREAKDOWN
h
->
sync
();
// h->sync();
timestamp
(
t_scatter
);
timestamp
(
t_scatter
);
fprintf
(
stderr
,
"Scatter time %.3lf us
\n
"
,
getDuration
(
t_expert
,
t_scatter
)
*
fprintf
(
stderr
,
"Scatter time %.3lf us
\n
"
,
getDuration
(
t_expert
,
t_scatter
)
*
1e6
);
1e6
);
...
@@ -147,7 +147,7 @@ void moe_cuda_forward_impl(
...
@@ -147,7 +147,7 @@ void moe_cuda_forward_impl(
in_feat
);
in_feat
);
#endif
#endif
// Use T(B) x T(A) = T(C) to produce row-major C
// Use T(B) x T(A) = T(C) to produce row-major C
checkCudaErrors
(
cublasXgemm
(
h
->
getHandle
(
i
),
checkCudaErrors
(
cublasXgemm
(
smgr
.
handle
,
// h->getHandle(i),
CUBLAS_OP_T
,
CUBLAS_OP_T
,
CUBLAS_OP_N
,
CUBLAS_OP_N
,
out_feat
,
expert_count
[
i
],
in_feat
,
out_feat
,
expert_count
[
i
],
in_feat
,
...
@@ -167,11 +167,11 @@ void moe_cuda_forward_impl(
...
@@ -167,11 +167,11 @@ void moe_cuda_forward_impl(
1e6
);
1e6
);
#endif
#endif
h
->
sync
();
// h->sync();
batch_gather_kernel
<
scalar_t
>
batch_gather_kernel
<
scalar_t
>
<<<
batch_size
,
256
,
0
,
h
->
getStream
(
0
)
>>>
(
out_feat
,
d_pos
,
output_buf
,
<<<
batch_size
,
256
,
0
,
smgr
.
streams
[
0
]
>>>
(
out_feat
,
d_pos
,
output_buf
,
output
);
output
);
h
->
sync
(
0
);
// h->sync(0);
#ifdef MOE_BREAKDOWN
#ifdef MOE_BREAKDOWN
timestamp
(
t_gather
);
timestamp
(
t_gather
);
...
@@ -203,8 +203,8 @@ void moe_cuda_grad_weight(
...
@@ -203,8 +203,8 @@ void moe_cuda_grad_weight(
scalar_t
alpha
=
1
,
beta
=
1
;
scalar_t
alpha
=
1
,
beta
=
1
;
checkCudaErrors
(
cudaMemcpy
(
gate_host
,
gate
,
batch_size
*
sizeof
(
int
),
cudaMemcpyDeviceToHost
));
checkCudaErrors
(
cudaMemcpy
(
gate_host
,
gate
,
batch_size
*
sizeof
(
int
),
cudaMemcpyDeviceToHost
));
for
(
size_t
i
=
0
;
i
<
batch_size
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
batch_size
;
++
i
)
{
checkCudaErrors
(
cublasSetStream
(
h
->
handles
[
0
],
*
(
h
->
streams
+
gate_host
[
i
]))
);
// checkCudaErrors(cublasSetStream);
checkCudaErrors
(
cublasXgemm
(
h
->
handles
[
0
]
,
checkCudaErrors
(
cublasXgemm
(
smgr
.
handle
,
CUBLAS_OP_N
,
CUBLAS_OP_N
,
CUBLAS_OP_T
,
CUBLAS_OP_T
,
out_feat
,
out_feat
,
...
@@ -253,7 +253,8 @@ std::vector<torch::Tensor> moe_cuda_forward(
...
@@ -253,7 +253,8 @@ std::vector<torch::Tensor> moe_cuda_forward(
batch_size
,
batch_size
,
in_feat
,
in_feat
,
out_feat
,
out_feat
,
num_expert
num_expert
,
CUBLAS_OP_T
);
);
}));
}));
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment