OpenDAS / FastMoE / Commits

Commit c5f73a0f
authored Mar 26, 2021 by TiagoMAntunes
parent 6cdb3cda

    Bias moved to CUDA in forward. Basic bias setup for backwards (no kernel yet)

Showing 3 changed files with 47 additions and 50 deletions:

    cuda/moe.cpp                 +17  -28
    cuda/moe_compute_kernel.cu   +27  -20
    cuda/moe_cuda_kernel.h        +3   -2
cuda/moe.cpp  (view file @ c5f73a0f)

@@ -32,55 +32,44 @@ std::vector<torch::Tensor> moe_local_gather(
     return moe_cuda_local_gather(output_buf, pos);
 }
 
-void merge_bias(torch::Tensor &input_buf, torch::Tensor &weight, at::optional<torch::Tensor> bias_o) {
-    torch::Tensor bias = bias_o.value();
-    weight = at::cat({weight, bias.unsqueeze(2)}, 2); // [W b]
-    auto options = torch::TensorOptions().device(input_buf.device()).dtype(input_buf.dtype());
-    auto ones = at::ones(input_buf.size(0), options).unsqueeze(1);
-    input_buf = at::cat({input_buf, ones}, 1); // [X 1]
-}
-
 std::vector<torch::Tensor> moe_forward(
     torch::Tensor input_buf,            // [batch_size x in_feat]
-    torch::Tensor expert_count,         // [batch_size]
+    torch::Tensor expert_count,         // [num_expert]
     torch::Tensor weight,               // [num_expert x out_feat x in_feat]
     at::optional<torch::Tensor> bias_o  // [num_expert x out_feat] or None
     ) {
-    // Wx+b = [W b] [x]
-    //              [1]
-    if (bias_o.has_value())
-        merge_bias(input_buf, weight, bias_o);
-
     CHECK_INPUT(input_buf);
     CHECK_INPUT(weight);
-    return moe_cuda_forward(input_buf, expert_count, weight);
+
+    // check if bias is valid in case it exists
+    if (bias_o.has_value()) {
+        auto bias = bias_o.value();
+        CHECK_INPUT(bias);
+    }
+
+    return moe_cuda_forward(input_buf, expert_count, weight, bias_o);
 }
 
 std::vector<torch::Tensor> moe_backward(
     torch::Tensor grad_output_buf,      // [batch_size x out_feat]
     torch::Tensor input_buf,            // [batch_size x in_feat]
-    torch::Tensor expert_count,
+    torch::Tensor expert_count,         // [num_expert]
     torch::Tensor weight,               // [num_expert x out_feat x in_feat]
     at::optional<torch::Tensor> bias_o  // [num_expert x out_feat] or None
     ) {
-    // Wx+b = [W b] [x]
-    //              [1]
-    if (bias_o.has_value())
-        merge_bias(input_buf, weight, bias_o);
-
     CHECK_INPUT(grad_output_buf);
     CHECK_INPUT(input_buf);
     CHECK_INPUT(weight);
-    return moe_cuda_backward(grad_output_buf, input_buf, expert_count, weight, bias_o.has_value());
+
+    // check if bias is valid in case it exists
+    if (bias_o.has_value()) {
+        auto bias = bias_o.value();
+        CHECK_INPUT(bias);
+    }
+
+    return moe_cuda_backward(grad_output_buf, input_buf, expert_count, weight, bias_o);
 }
 
 #ifdef MOE_USE_NCCL
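Why the removed merge_bias worked: it folded the bias into the weight matrix via the augmentation identity Wx + b = [W b]·[x; 1], at the cost of two host-side at::cat copies per call. Below is a standalone sketch of that identity, assuming a single expert and plain libtorch; it is an illustration, not code from this commit.

#include <torch/torch.h>
#include <iostream>

int main() {
    auto W = torch::randn({5, 3});                       // [out_feat x in_feat]
    auto b = torch::randn({5});                          // [out_feat]
    auto x = torch::randn({4, 3});                       // [batch_size x in_feat]

    auto Wb = torch::cat({W, b.unsqueeze(1)}, 1);        // [W b]
    auto x1 = torch::cat({x, torch::ones({4, 1})}, 1);   // [x 1]

    // The augmented GEMM reproduces Wx + b exactly.
    std::cout << torch::allclose(x1.matmul(Wb.t()), x.matmul(W.t()) + b)
              << std::endl;                              // prints 1
    return 0;
}

This commit drops that concatenation (and the matching gradient-splitting code removed from moe_cuda_backward below) in favor of handling the bias inside the CUDA path itself.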
cuda/moe_compute_kernel.cu  (view file @ c5f73a0f)

@@ -118,11 +118,12 @@ void moe_cuda_forward_impl(
     const scalar_t* weight,
     const long* expert_count,
     scalar_t* output_buf,
+    const bool has_bias,
     const size_t in_feat,
     const size_t out_feat,
     const size_t num_expert,
     CudaStreamManager* smgr) {
-    scalar_t alpha = 1, beta = 0;
+    scalar_t alpha = 1, beta = has_bias ? 1 : 0;
     for (int i = 0, ptr = 0; i < num_expert; ++i) {
         if (expert_count[i] == 0) {
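The beta change is the heart of the forward-path fusion: cuBLAS GEMMs compute C <- alpha*A*B + beta*C, so with beta = 1 the multiplication accumulates onto whatever output_buf already holds, and moe_cuda_forward (below) pre-fills it with the bias rows. A host-side sketch of the same alpha/beta contract using torch::addmm, assuming plain libtorch; illustration only, not this repo's code:

#include <torch/torch.h>
#include <iostream>

int main() {
    auto x = torch::randn({4, 3});   // [batch_size x in_feat]
    auto w = torch::randn({5, 3});   // [out_feat x in_feat]
    auto b = torch::randn({5});      // [out_feat]

    // beta = 0: the pre-filled output is ignored, y = x W^T (old behaviour).
    auto y0 = torch::addmm(b.repeat({4, 1}), x, w.t(), /*beta=*/0, /*alpha=*/1);

    // beta = 1: the GEMM adds onto the bias already sitting in the output.
    auto y1 = torch::addmm(b.repeat({4, 1}), x, w.t(), /*beta=*/1, /*alpha=*/1);

    std::cout << torch::allclose(y0 + b, y1) << std::endl;  // prints 1
    return 0;
}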
@@ -154,6 +155,8 @@ void moe_cuda_backward_impl(
     const long* expert_count,
     scalar_t* grad_input_buf,
     scalar_t* grad_weight,
+    scalar_t* grad_bias,
+    const bool has_bias,
     const size_t batch_size,
     const size_t in_feat,
     const size_t out_feat,

@@ -194,6 +197,10 @@ void moe_cuda_backward_impl(
             &beta,
             grad_weight + i * in_feat * out_feat, in_feat
         ));
+
+        if (has_bias) {
+            // call bias kernel here
+        }
+
         ptr += expert_count[i];
     }
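The placeholder above is where the bias gradient will eventually be computed on the GPU; the commit message is explicit that no kernel ships yet. Since y = Wx + b gives dL/db as the sum of dL/dy over the batch, the kernel only needs a column-wise sum over the grad_output rows routed to each expert. A hypothetical sketch of what it might look like; the name and launch shape are assumptions, not this repo's API:

// One thread per output feature; each thread sums its column of this
// expert's grad_output rows.
template <typename scalar_t>
__global__ void moe_cuda_bias_backward_sketch(
        const scalar_t* grad_output,  // this expert's rows: [count x out_feat]
        scalar_t* grad_bias,          // this expert's slice: [out_feat]
        size_t count,                 // expert_count[i]
        size_t out_feat) {
    size_t j = blockIdx.x * blockDim.x + threadIdx.x;
    if (j >= out_feat) return;
    scalar_t acc = 0;
    for (size_t r = 0; r < count; ++r)
        acc += grad_output[r * out_feat + j];
    grad_bias[j] = acc;
}

Launched once per expert (e.g. over grad_output_buf + ptr * out_feat and grad_bias + i * out_feat, with ptr the running row offset from the loop), this would fill the grad_bias buffer allocated in moe_cuda_backward below.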
@@ -276,7 +283,8 @@ std::vector<torch::Tensor> moe_cuda_local_gather(
 std::vector<torch::Tensor> moe_cuda_forward(
     torch::Tensor input_buf,
     torch::Tensor expert_count,
-    torch::Tensor weight
+    torch::Tensor weight,
+    at::optional<torch::Tensor> bias
     ) {
     auto smgr = getCudaStreamManager(input_buf.device().index());
     const auto batch_size = input_buf.size(0);
@@ -288,11 +296,18 @@ std::vector<torch::Tensor> moe_cuda_forward(
     printf("[forward] expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n",
            num_expert, in_feat, out_feat);
 #endif
-    auto out_options = torch::TensorOptions()
-        .device(input_buf.device())
-        .dtype(input_buf.dtype());
-    auto output = torch::empty({batch_size, out_feat}, out_options);
+    torch::Tensor output;
+    if (bias.has_value()) {
+        output = bias.value().repeat_interleave(expert_count.to(bias.value().device()), 0);
+    } else {
+        auto out_options = torch::TensorOptions()
+            .device(input_buf.device())
+            .dtype(input_buf.dtype());
+        output = torch::empty({batch_size, out_feat}, out_options);
+    }
     AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(), "moe_forward_cuda",
             ([&] {
         moe_cuda_forward_impl<scalar_t>(

@@ -300,6 +315,7 @@ std::vector<torch::Tensor> moe_cuda_forward(
             weight.data_ptr<scalar_t>(),
             expert_count.data_ptr<long>(),
             output.data_ptr<scalar_t>(),
+            bias.has_value(),
             in_feat,
             out_feat,
             num_expert,
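How the pre-fill lines up with the GEMM: tokens in input_buf are already gathered so that each expert's rows are contiguous, and repeat_interleave(expert_count, 0) repeats expert i's bias row expert_count[i] times, so row r of output holds exactly the bias of the expert that will process row r. A minimal sketch of that replication step alone, assuming plain libtorch; not code from this repo:

#include <torch/torch.h>
#include <iostream>

int main() {
    auto bias = torch::tensor({{1.f, 10.f},
                               {2.f, 20.f}});     // [num_expert=2 x out_feat=2]
    auto expert_count = torch::tensor({3L, 1L});  // 3 tokens for expert 0, 1 for expert 1

    // Rows 0-2 get expert 0's bias, row 3 gets expert 1's: [batch_size x out_feat].
    auto output = bias.repeat_interleave(expert_count, 0);
    std::cout << output << std::endl;
    return 0;
}

With beta = 1 in moe_cuda_forward_impl, the per-expert GEMMs then add x*W^T on top of these rows, completing Wx + b without a separate bias kernel.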
@@ -315,7 +331,7 @@ std::vector<torch::Tensor> moe_cuda_backward(
     torch::Tensor input_buf,        // [batch_size x out_feat]
     torch::Tensor expert_count,
     torch::Tensor weight,           // [num_expert x out_feat x in_feat]
-    bool has_bias
+    at::optional<torch::Tensor> bias
     ) {
     auto smgr = getCudaStreamManager(input_buf.device().index());
     const auto batch_size = input_buf.size(0);

@@ -331,6 +347,7 @@ std::vector<torch::Tensor> moe_cuda_backward(
     auto grad_input_buf = grad_output_buf.new_empty({batch_size, in_feat});
     auto grad_weight = grad_output_buf.new_empty({num_expert, out_feat, in_feat});
+    auto grad_bias = grad_output_buf.new_empty({num_expert, out_feat});
     AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(), "moe_cuda_backward", ([&] {
         moe_cuda_backward_impl<scalar_t>(

@@ -340,6 +357,8 @@ std::vector<torch::Tensor> moe_cuda_backward(
             expert_count.data_ptr<long>(),
             grad_input_buf.data_ptr<scalar_t>(),
             grad_weight.data_ptr<scalar_t>(),
+            grad_bias.data_ptr<scalar_t>(),
+            bias.has_value(),
             batch_size,
             in_feat,
             out_feat,

@@ -348,17 +367,5 @@ std::vector<torch::Tensor> moe_cuda_backward(
         );
     }));
 
-    if (!has_bias) return {grad_input_buf, grad_weight, torch::empty({num_expert, out_feat})};
-
-    // weight and input have been concatenated. need to split the grads back
-    // and separate them into input, weight, bias
-    torch::Tensor grad_orig_input_buf = at::narrow(grad_input_buf, -1, 0, in_feat - 1).contiguous();
-
-    // bias is also squeezed in the new added dimension
-    torch::Tensor grad_orig_bias = at::narrow(grad_weight, -1, in_feat - 1, 1).squeeze(2).contiguous();
-    torch::Tensor grad_orig_weight = at::narrow(grad_weight, -1, 0, in_feat - 1).contiguous();
-
-    return {grad_orig_input_buf, grad_orig_weight, grad_orig_bias};
+    return {grad_input_buf, grad_weight, grad_bias};
 }
cuda/moe_cuda_kernel.h  (view file @ c5f73a0f)

@@ -20,14 +20,15 @@ std::vector<torch::Tensor> moe_cuda_local_gather(
 std::vector<torch::Tensor> moe_cuda_forward(
     torch::Tensor input_buf,
     torch::Tensor expert_count,
-    torch::Tensor weight);
+    torch::Tensor weight,
+    at::optional<torch::Tensor> bias);
 
 std::vector<torch::Tensor> moe_cuda_backward(
     torch::Tensor grad_output_buf,
     torch::Tensor input_buf,
     torch::Tensor expert_count,
     torch::Tensor weight,
-    bool has_bias);
+    at::optional<torch::Tensor> bias);
 
 #ifdef MOE_USE_NCCL