OpenDAS / FastMoE · Commits

Commit d56522bc (unverified)
Authored Jun 06, 2023 by Mingshu Zhai; committed by GitHub on Jun 06, 2023

    Merge pull request #160 from laekov/nvbf16

    bf16 support

Parents: c9ccc0eb, 0bb60881

Showing 8 changed files with 56 additions and 26 deletions (+56 -26):
    cuda/fastermoe/smart_schedule.cpp   +4   -4
    cuda/global_exchange.cpp            +4   -4
    cuda/parallel_linear.cu             +4   -2
    cuda/utils/cublas_wrapper.h         +22  -0
    setup.py                            +1   -1
    tests/moe.py                        +2   -2
    tests/test_mimo.py                  +4   -4
    tests/test_numerical.py             +15  -9
cuda/fastermoe/smart_schedule.cpp

@@ -52,8 +52,8 @@ void _reduce_grad(
     cudaEventDestroy(evt_stash);
     auto dtype = getNcclDataType(t.scalar_type());
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(t.scalar_type(),
-            "fmoe_cuda_reduce_grad", ([&] {
+    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
+            t.scalar_type(), "fmoe_cuda_reduce_grad", ([&] {
         void* buf = (void*)t.data_ptr<scalar_t>();
         NCCL_SAFE_CALL(ncclReduce(buf, buf, expert_size,
                 dtype,

@@ -110,8 +110,8 @@ std::vector<torch::Tensor> _smart_sch_forward(
         }
     }
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(),
-            "fmoe_cuda_smart_sch_forward", ([&] {
+    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
+            input_buf.scalar_type(), "fmoe_cuda_smart_sch_forward", ([&] {
         fmoe_cuda_fused_forward_impl(
             forward_fn,
             stash_fn,
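Every dispatch-macro hunk in this commit follows the same pattern: AT_DISPATCH_FLOATING_TYPES_AND_HALF instantiates its lambda only for float, double, and Half, so bf16 tensors could not dispatch at all, while AT_DISPATCH_FLOATING_TYPES_AND2 lets the caller name two extra scalar types, here Half and BFloat16. A minimal self-contained sketch of the macro's use (sum_any_float is a hypothetical example, not FastMoE code):

// Sketch of the dispatch pattern applied throughout this commit; the
// macro expands the lambda once per supported dtype, binding `scalar_t`.
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

double sum_any_float(const at::Tensor& t) {
    double acc = 0.0;
    // Covers float/double plus the two extra types named here, so bf16
    // tensors no longer fall through to a "not implemented" error.
    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
            t.scalar_type(), "sum_any_float", ([&] {
        const scalar_t* p = t.data_ptr<scalar_t>();
        for (int64_t i = 0; i < t.numel(); ++i)
            acc += static_cast<double>(p[i]);
    }));
    return acc;
}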
cuda/global_exchange.cpp

@@ -58,8 +58,8 @@ torch::Tensor _global_scatter(
     auto global_input_buf = input_buf.new_empty({batch_size, in_feat});
     auto smgr = getCudaStreamManager(input_buf.device().index());
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(),
-            "fmoe_cuda_global_scatter", ([&] {
+    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
+            input_buf.scalar_type(), "fmoe_cuda_global_scatter", ([&] {
         fmoe_cuda_global_scatter_impl<scalar_t>(
             input_buf.data_ptr<scalar_t>(),
             local_expert_count.data_ptr<long>(),

@@ -84,8 +84,8 @@ torch::Tensor _global_gather(
     auto local_output_buf = output_buf.new_empty({batch_size, out_feat});
     auto smgr = getCudaStreamManager(output_buf.device().index());
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(output_buf.scalar_type(),
-            "fmoe_cuda_global_gather", ([&] {
+    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
+            output_buf.scalar_type(), "fmoe_cuda_global_gather", ([&] {
         fmoe_cuda_global_gather_impl<scalar_t>(
             output_buf.data_ptr<scalar_t>(),
             local_expert_count.data_ptr<long>(),
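These exchange kernels run over NCCL, and the _reduce_grad hunk earlier fetches the matching communication dtype via getNcclDataType(t.scalar_type()), whose body lies outside this diff. A plausible sketch of such a mapping (assumed, not the repository's actual implementation); note that ncclBfloat16 only exists in NCCL 2.10 and later:

// Assumed sketch of a scalar-type-to-NCCL mapping in the spirit of the
// getNcclDataType() call above; not taken from the FastMoE sources.
#include <stdexcept>
#include <nccl.h>
#include <ATen/ATen.h>

ncclDataType_t getNcclDataTypeSketch(at::ScalarType t) {
    switch (t) {
    case at::ScalarType::Float:    return ncclFloat32;
    case at::ScalarType::Double:   return ncclFloat64;
    case at::ScalarType::Half:     return ncclFloat16;
#if defined(NCCL_MAJOR) && (NCCL_MAJOR > 2 || (NCCL_MAJOR == 2 && NCCL_MINOR >= 10))
    case at::ScalarType::BFloat16: return ncclBfloat16;  // NCCL >= 2.10 only
#endif
    default:
        throw std::runtime_error("dtype not supported by NCCL");
    }
}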
cuda/parallel_linear.cu

@@ -30,7 +30,8 @@ torch::Tensor _linear_forward(
         output = torch::empty({batch_size, out_feat}, out_options);
     }
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(), "moe_forward_cuda",
+    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
+            input_buf.scalar_type(), "moe_forward_cuda",
             ([&] {
         fmoe_cuda_linear_forward_impl<scalar_t>(
             input_buf.data_ptr<scalar_t>(),

@@ -72,7 +73,8 @@ std::vector<torch::Tensor> _linear_backward(
     auto grad_weight = grad_output_buf.new_empty({num_expert, out_feat, in_feat});
     auto grad_bias = grad_output_buf.new_empty({num_expert, out_feat});
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(), "moe_cuda_backward", ([&] {
+    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
+            input_buf.scalar_type(), "moe_cuda_backward", ([&] {
         fmoe_cuda_linear_backward_impl<scalar_t>(
             grad_output_buf.data_ptr<scalar_t>(),
             input_buf.data_ptr<scalar_t>(),
cuda/utils/cublas_wrapper.h

@@ -2,6 +2,7 @@
 #define CUBLAS_WRAPPER_H
 #include <cublas_v2.h>
 #include <c10/util/Half.h>
+#include <c10/util/BFloat16.h>

 inline cublasStatus_t cublasXgemmBatched(cublasHandle_t handle,
         cublasOperation_t transa,

@@ -108,5 +109,26 @@ inline cublasStatus_t cublasXgemm(cublasHandle_t handle,
             (__half*)C, ldc);
 #endif
 }
+inline cublasStatus_t cublasXgemm(cublasHandle_t handle,
+        cublasOperation_t transa,
+        cublasOperation_t transb,
+        int m, int n, int k,
+        const c10::BFloat16 *alpha,
+        const c10::BFloat16 *A, int lda,
+        const c10::BFloat16 *B, int ldb,
+        const c10::BFloat16 *beta,
+        c10::BFloat16 *C, int ldc) {
+#ifdef FMOE_USE_HIP
+    // TODO: Support bf16 for HIP
+    assert(false);
+#else
+    return cublasSgemmEx(handle, transa, transb, m, n, k,
+            (const float*)alpha,
+            (const void*)A, CUDA_R_16F, lda,
+            (const void*)B, CUDA_R_16F, ldb,
+            (const float*)beta,
+            (void*)C, CUDA_R_16F, ldc);
+#endif
+}
 #endif  // CUBLAS_WRAPPER_H
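The added overload completes the cublasXgemm family for c10::BFloat16 (with an assert placeholder on HIP), delegating to cublasSgemmEx. For comparison only, cuBLAS on CUDA 11+ also accepts bf16 storage types directly; a sketch, not from this commit, of the same GEMM through cublasGemmEx with CUDA_R_16BF operands and fp32 accumulation:

// Comparison sketch, not part of this commit: a bf16 GEMM expressed with
// cublasGemmEx, which takes explicit storage types (CUDA_R_16BF) and a
// separate compute type (CUBLAS_COMPUTE_32F accumulates in fp32).
#include <cublas_v2.h>
#include <cuda_bf16.h>

cublasStatus_t bf16_gemm_sketch(cublasHandle_t handle,
        int m, int n, int k,
        const __nv_bfloat16* A, int lda,
        const __nv_bfloat16* B, int ldb,
        __nv_bfloat16* C, int ldc) {
    // Scalars are fp32 because the compute type is 32F.
    const float alpha = 1.0f, beta = 0.0f;
    return cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
            &alpha,
            A, CUDA_R_16BF, lda,
            B, CUDA_R_16BF, ldb,
            &beta,
            C, CUDA_R_16BF, ldc,
            CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT);
}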
setup.py

@@ -41,7 +41,7 @@ else:
 if __name__ == '__main__':
     setuptools.setup(
         name='fastmoe',
-        version='1.0.1',
+        version='1.0.2',
         description='An efficient Mixture-of-Experts system for PyTorch',
         author=', '.join(authors),
         author_email='hja20@mails.tsinghua.edu.cn',
tests/moe.py

@@ -80,7 +80,7 @@ class NaiveExpert(nn.Module):
         super(NaiveExpert, self).__init__()
         self.linear = nn.Linear(d_model, d_model).cuda()

-    def forward(self, x):
+    def forward(self, x, fec=None):
         return self.linear(x)

@@ -91,5 +91,5 @@ class LinearExpert(nn.Module):
             nn.Linear(d_model, d_model * 2),
             nn.ReLU(),
             nn.Linear(d_model * 2, d_model),
         ).cuda()

-    def forward(self, x):
+    def forward(self, x, fec=None):
         return self.model(x)
tests/test_mimo.py

@@ -108,7 +108,7 @@ class MyMoE(FMoE):
 @pytest.mark.parametrize("dp_group", [None])
 @pytest.mark.parametrize("world_group", [None])
 @pytest.mark.parametrize(
-    "data_type", ["torch.FloatTensor", "torch.DoubleTensor", "torch.HalfTensor"]
+    "data_type", [torch.float32, torch.float16, torch.bfloat16, torch.double]
 )
 @pytest.mark.parametrize("list_input", [False, True])
 def test_fmoe_mimo_linear(

@@ -138,9 +138,9 @@ def test_fmoe_mimo_linear(
         mp_group=mp_group,
         top_k=top_k,
         activation=activation,
-    ).cuda()
-    x = torch.rand(batch_size, d_model).cuda()
+    ).cuda().to(data_type)
+    x = torch.rand(batch_size, d_model).cuda().to(data_type)
     inp = [x, x.clone()] if list_input else {"x": x, "y": x.clone()}
     moe_out = moe(inp)

@@ -162,6 +162,6 @@ if __name__ == "__main__":
         mp_group=None,
         dp_group=None,
         world_group=None,
-        data_type=torch.float32,
+        data_type=torch.bfloat16,
         list_input=True
     )
tests/test_numerical.py

@@ -51,6 +51,8 @@ def _perform_forward(
 def _assert_numerical(names, moe_out_list, raw_out_list, rank, precision=1e-3):
     for name, mo, ro in zip(names, moe_out_list, raw_out_list):
         err = (mo - ro).abs().max()
+        if err.dtype == torch.bfloat16:
+            precision *= 100
         print("Rank {} {} abs err {}".format(rank, name, err))
         if err > precision:
             sys.stderr.write(f"=========== {name} moe out ==============\n")
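The 100x loosening matches bf16's precision: with an 8-bit significand, bf16's machine epsilon is 2^-8 ≈ 3.9e-3, versus 2^-23 ≈ 1.2e-7 for fp32, so the default 1e-3 bound sits below even a single bf16 rounding step. Scaling it to 1e-1 leaves headroom for rounding error to accumulate across each d_model-long reduction.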
@@ -217,6 +219,7 @@ def test_fmoe_linear(
 @pytest.mark.parametrize("mp_group", [None])
 @pytest.mark.parametrize("dp_group", [None])
 @pytest.mark.parametrize("world_group", [None])
+@pytest.mark.parametrize("data_type", [torch.float32, torch.float16, torch.bfloat16])
 def test_fmoe(
     batch_size,
     num_expert,

@@ -228,6 +231,7 @@ def test_fmoe(
     mp_group,
     dp_group,
     world_group,
+    data_type
 ):
     torch.manual_seed(42 + rank)
     torch.cuda.manual_seed(42 + rank)

@@ -243,7 +247,7 @@ def test_fmoe(
         mp_group=mp_group,
         expert=expert,
         top_k=top_k,
-    ).cuda()
+    ).cuda().to(data_type)
     moe_raw = BruteForceMoE(
         expert=expert,

@@ -251,7 +255,7 @@ def test_fmoe(
         d_model=d_model,
         world_size=world_size,
         top_k=top_k,
-    ).cuda()
+    ).cuda().to(data_type)
     if world_size == 1:
         for expert_moe, expert_raw in zip(moe.experts, moe_raw.experts):

@@ -275,7 +279,7 @@ def test_fmoe(
         ].data = para_tensor_gathered[expertID]
     moe_out, raw_out, moe_grad_in, raw_grad_in = _perform_forward(
-        moe, moe_raw, batch_size, d_model, top_k, rank, mp_group
+        moe, moe_raw, batch_size, d_model, top_k, rank, mp_group, data_type
     )

     def get_experts_grad(experts: List[nn.Module]):

@@ -396,6 +400,7 @@ def _test_fmoe_local_ddp(rank, world_size, mp_group, dp_group, world_group):
 @pytest.mark.parametrize("mp_group", [None])
 @pytest.mark.parametrize("dp_group", [None])
 @pytest.mark.parametrize("world_group", [None])
+@pytest.mark.parametrize("data_type", [torch.float32])
 def test_fmoe_experts(
     batch_size,
     num_expert,

@@ -407,6 +412,7 @@ def test_fmoe_experts(
     mp_group,
     dp_group,
     world_group,
+    data_type
 ):
     torch.manual_seed(42 + rank)
     torch.cuda.manual_seed(42 + rank)

@@ -422,7 +428,7 @@ def test_fmoe_experts(
         mp_group=mp_group,
         expert=expert,
         top_k=top_k,
-    ).cuda()
+    ).cuda().to(data_type)
     moe_raw = BruteForceMoE(
         expert=expert,

@@ -430,7 +436,7 @@ def test_fmoe_experts(
         d_model=d_model,
         world_size=world_size,
         top_k=top_k,
-    ).cuda()
+    ).cuda().to(data_type)
     if world_size == 1:
         for expert_moe, expert_raw in zip(moe.experts, moe_raw.experts):

@@ -454,7 +460,7 @@ def test_fmoe_experts(
         ].data = para_tensor_gathered[expertID]
     moe_out, raw_out, moe_grad_in, raw_grad_in = _perform_forward(
-        moe, moe_raw, batch_size, d_model, top_k, rank, mp_group
+        moe, moe_raw, batch_size, d_model, top_k, rank, mp_group, data_type
     )

     def get_experts_grad(experts: List[nn.Module]):

@@ -488,16 +494,16 @@ def test_fmoe_experts(
 if __name__ == "__main__":
-    test_fmoe_linear(
+    test_fmoe(
         batch_size=2,
         num_expert=2,
         d_model=2,
         top_k=2,
-        d_hidden=16,
+        expert=[NaiveExpert for _ in range(4)],
         rank=0,
         world_size=1,
         mp_group=None,
         dp_group=None,
         world_group=None,
-        data_type=torch.float32,
+        data_type=torch.bfloat16
     )