Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
8c213ef1
Unverified
Commit
8c213ef1
authored
Jul 31, 2023
by
Ilia Taraban
Committed by
GitHub
Jul 31, 2023
Browse files
[Feature] Enable bfloat16 convert functions in Python API (#5760)
parent
b6f5ba9a
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
183 additions
and
13 deletions
+183
-13
docs/source/guide/mixed_precision.rst
docs/source/guide/mixed_precision.rst
+94
-10
examples/pytorch/gat/train.py
examples/pytorch/gat/train.py
+13
-0
examples/pytorch/gcn/train.py
examples/pytorch/gcn/train.py
+12
-0
examples/pytorch/graphsage/node_classification.py
examples/pytorch/graphsage/node_classification.py
+12
-0
examples/pytorch/graphsage/train_full.py
examples/pytorch/graphsage/train_full.py
+13
-0
python/dgl/backend/backend.py
python/dgl/backend/backend.py
+1
-0
python/dgl/backend/pytorch/tensor.py
python/dgl/backend/pytorch/tensor.py
+1
-0
python/dgl/backend/tensorflow/tensor.py
python/dgl/backend/tensorflow/tensor.py
+1
-0
python/dgl/frame.py
python/dgl/frame.py
+12
-1
python/dgl/transforms/functional.py
python/dgl/transforms/functional.py
+19
-0
tests/python/common/test_heterograph.py
tests/python/common/test_heterograph.py
+5
-2
No files found.
docs/source/guide/mixed_precision.rst
View file @
8c213ef1
...
@@ -4,8 +4,8 @@ Chapter 8: Mixed Precision Training
...
@@ -4,8 +4,8 @@ Chapter 8: Mixed Precision Training
===================================
===================================
DGL
is
compatible
with
the
`
PyTorch
Automatic
Mixed
Precision
(
AMP
)
package
DGL
is
compatible
with
the
`
PyTorch
Automatic
Mixed
Precision
(
AMP
)
package
<
https
://
pytorch
.
org
/
docs
/
stable
/
amp
.
html
>`
_
<
https
://
pytorch
.
org
/
docs
/
stable
/
amp
.
html
>`
_
for
mixed
precision
training
,
thus
saving
both
training
time
and
GPU
memory
for
mixed
precision
training
,
thus
saving
both
training
time
and
GPU
/
CPU
memory
consumption
.
This
feature
requires
DGL
0.9
+.
consumption
.
This
feature
requires
DGL
0.9
+
and
1.1
+
for
CPU
bloat16
.
Message
-
Passing
with
Half
Precision
Message
-
Passing
with
Half
Precision
-----------------------------------
-----------------------------------
...
@@ -58,18 +58,19 @@ DGL relies on PyTorch's AMP package for mixed precision training,
...
@@ -58,18 +58,19 @@ DGL relies on PyTorch's AMP package for mixed precision training,
and the user experience is exactly
and the user experience is exactly
the same as `PyTorch'
s
<
https
://
pytorch
.
org
/
docs
/
stable
/
notes
/
amp_examples
.
html
>`
_
.
the same as `PyTorch'
s
<
https
://
pytorch
.
org
/
docs
/
stable
/
notes
/
amp_examples
.
html
>`
_
.
By
wrapping
the
forward
pass
with
``
torch
.
cuda
.
amp
.
autocast
()``,
PyTorch
automatically
By
wrapping
the
forward
pass
with
``
torch
.
amp
.
autocast
()``,
PyTorch
automatically
selects
the
appropriate
datatype
for
each
op
and
tensor
.
Half
precision
tensors
are
memory
selects
the
appropriate
datatype
for
each
op
and
tensor
.
Half
precision
tensors
are
memory
efficient
,
most
operators
on
half
precision
tensors
are
faster
as
they
leverage
GPU
tensorcores
.
efficient
,
most
operators
on
half
precision
tensors
are
faster
as
they
leverage
GPU
tensorcores
and
CPU
special
instructon
set
.
..
code
::
..
code
::
import
torch
.
nn
.
functional
as
F
import
torch
.
nn
.
functional
as
F
from
torch
.
cuda
.
amp
import
autocast
from
torch
.
amp
import
autocast
def
forward
(
g
,
feat
,
label
,
mask
,
model
,
amp_dtype
):
def
forward
(
device_type
,
g
,
feat
,
label
,
mask
,
model
,
amp_dtype
):
amp_enabled
=
amp_dtype
in
(
torch
.
float16
,
torch
.
bfloat16
)
amp_enabled
=
amp_dtype
in
(
torch
.
float16
,
torch
.
bfloat16
)
with
autocast
(
enabled
=
amp_enabled
,
dtype
=
amp_dtype
):
with
autocast
(
device_type
,
enabled
=
amp_enabled
,
dtype
=
amp_dtype
):
logit
=
model
(
g
,
feat
)
logit
=
model
(
g
,
feat
)
loss
=
F
.
cross_entropy
(
logit
[
mask
],
label
[
mask
])
loss
=
F
.
cross_entropy
(
logit
[
mask
],
label
[
mask
])
return
loss
return
loss
...
@@ -104,7 +105,7 @@ Pay attention to the differences in the code when AMP is activated or not.
...
@@ -104,7 +105,7 @@ Pay attention to the differences in the code when AMP is activated or not.
from dgl.nn import GATConv
from dgl.nn import GATConv
from dgl.transforms import AddSelfLoop
from dgl.transforms import AddSelfLoop
amp_dtype = torch.float16
# or torch.
b
float16
amp_dtype = torch.
b
float16 # or torch.float16
class GAT(nn.Module):
class GAT(nn.Module):
def __init__(self,
def __init__(self,
...
@@ -130,7 +131,8 @@ Pay attention to the differences in the code when AMP is activated or not.
...
@@ -130,7 +131,8 @@ Pay attention to the differences in the code when AMP is activated or not.
# Data loading
# Data loading
transform = AddSelfLoop()
transform = AddSelfLoop()
data = RedditDataset(transform)
data = RedditDataset(transform)
dev = torch.device('
cuda
')
device_type = '
cuda
' # or '
cpu
'
dev = torch.device(device_type)
g = data[0]
g = data[0]
g = g.int().to(dev)
g = g.int().to(dev)
...
@@ -151,7 +153,7 @@ Pay attention to the differences in the code when AMP is activated or not.
...
@@ -151,7 +153,7 @@ Pay attention to the differences in the code when AMP is activated or not.
for epoch in range(100):
for epoch in range(100):
optimizer.zero_grad()
optimizer.zero_grad()
loss = forward(g, feat, label, train_mask, model, amp_dtype)
loss = forward(
device_type,
g, feat, label, train_mask, model, amp_dtype)
if amp_dtype == torch.float16:
if amp_dtype == torch.float16:
# Backprop w/ gradient scaling
# Backprop w/ gradient scaling
...
@@ -169,5 +171,87 @@ If we change the number of heads to ``[2, 2, 2]``, training without fp16
...
@@ -169,5 +171,87 @@ If we change the number of heads to ``[2, 2, 2]``, training without fp16
triggers GPU OOM(out-of-memory) issue while training with fp16 consumes
triggers GPU OOM(out-of-memory) issue while training with fp16 consumes
15.7G GPU memory.
15.7G GPU memory.
BFloat16 CPU example
-----------------------------------
DGL supports running training in the bfloat16 data type on the CPU.
This data type doesn'
t
require
any
CPU
feature
and
can
improve
the
performance
of
a
memory
-
bound
model
.
Starting
with
Intel
Xeon
4
th
Generation
,
which
has
`
AMX
<
https
://
www
.
intel
.
com
/
content
/
www
/
us
/
en
/
products
/
docs
/
accelerator
-
engines
/
advanced
-
matrix
-
extensions
/
overview
.
html
>`
_
instructon
set
,
bfloat16
should
significantly
improve
training
and
inference
performance
without
huge
code
changes
.
Here
is
an
example
of
simple
GCN
bfloat16
training
:
..
code
::
import
torch
import
torch
.
nn
as
nn
import
torch
.
nn
.
functional
as
F
import
dgl
from
dgl
.
data
import
CiteseerGraphDataset
from
dgl
.
nn
import
GraphConv
from
dgl
.
transforms
import
AddSelfLoop
class
GCN
(
nn
.
Module
):
def
__init__
(
self
,
in_size
,
hid_size
,
out_size
):
super
().
__init__
()
self
.
layers
=
nn
.
ModuleList
()
#
two
-
layer
GCN
self
.
layers
.
append
(
GraphConv
(
in_size
,
hid_size
,
activation
=
F
.
relu
)
)
self
.
layers
.
append
(
GraphConv
(
hid_size
,
out_size
))
self
.
dropout
=
nn
.
Dropout
(
0.5
)
def
forward
(
self
,
g
,
features
):
h
=
features
for
i
,
layer
in
enumerate
(
self
.
layers
):
if
i
!= 0:
h
=
self
.
dropout
(
h
)
h
=
layer
(
g
,
h
)
return
h
#
Data
loading
transform
=
AddSelfLoop
()
data
=
CiteseerGraphDataset
(
transform
=
transform
)
g
=
data
[
0
]
g
=
g
.
int
()
train_mask
=
g
.
ndata
[
'train_mask'
]
feat
=
g
.
ndata
[
'feat'
]
label
=
g
.
ndata
[
'label'
]
in_size
=
feat
.
shape
[
1
]
hid_size
=
16
out_size
=
data
.
num_classes
model
=
GCN
(
in_size
,
hid_size
,
out_size
)
#
Convert
model
and
graph
to
bfloat16
g
=
dgl
.
to_bfloat16
(
g
)
feat
=
feat
.
to
(
dtype
=
torch
.
bfloat16
)
model
=
model
.
to
(
dtype
=
torch
.
bfloat16
)
model
.
train
()
#
Create
optimizer
optimizer
=
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
1e-2
,
weight_decay
=
5e-4
)
loss_fcn
=
nn
.
CrossEntropyLoss
()
for
epoch
in
range
(
100
):
logits
=
model
(
g
,
feat
)
loss
=
loss_fcn
(
logits
[
train_mask
],
label
[
train_mask
])
loss
.
backward
()
optimizer
.
step
()
print
(
'Epoch {} | Loss {}'
.
format
(
epoch
,
loss
.
item
()))
The
only
difference
with
common
training
is
model
and
graph
conversion
before
training
/
inference
.
..
code
::
g
=
dgl
.
to_bfloat16
(
g
)
feat
=
feat
.
to
(
dtype
=
torch
.
bfloat16
)
model
=
model
.
to
(
dtype
=
torch
.
bfloat16
)
DGL
is
still
improving
its
half
-
precision
support
and
the
compute
kernel
's
DGL
is
still
improving
its
half
-
precision
support
and
the
compute
kernel
's
performance is far from optimal, please stay tuned to our future updates.
performance is far from optimal, please stay tuned to our future updates.
examples/pytorch/gat/train.py
View file @
8c213ef1
import
argparse
import
argparse
import
dgl
import
dgl.nn
as
dglnn
import
dgl.nn
as
dglnn
import
torch
import
torch
...
@@ -88,6 +89,12 @@ if __name__ == "__main__":
...
@@ -88,6 +89,12 @@ if __name__ == "__main__":
default
=
"cora"
,
default
=
"cora"
,
help
=
"Dataset name ('cora', 'citeseer', 'pubmed')."
,
help
=
"Dataset name ('cora', 'citeseer', 'pubmed')."
,
)
)
parser
.
add_argument
(
"--dt"
,
type
=
str
,
default
=
"float"
,
help
=
"data type(float, bfloat16)"
,
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
print
(
f
"Training with DGL built-in GATConv module."
)
print
(
f
"Training with DGL built-in GATConv module."
)
...
@@ -115,6 +122,12 @@ if __name__ == "__main__":
...
@@ -115,6 +122,12 @@ if __name__ == "__main__":
out_size
=
data
.
num_classes
out_size
=
data
.
num_classes
model
=
GAT
(
in_size
,
8
,
out_size
,
heads
=
[
8
,
1
]).
to
(
device
)
model
=
GAT
(
in_size
,
8
,
out_size
,
heads
=
[
8
,
1
]).
to
(
device
)
# convert model and graph to bfloat16 if needed
if
args
.
dt
==
"bfloat16"
:
g
=
dgl
.
to_bfloat16
(
g
)
features
=
features
.
to
(
dtype
=
torch
.
bfloat16
)
model
=
model
.
to
(
dtype
=
torch
.
bfloat16
)
# model training
# model training
print
(
"Training..."
)
print
(
"Training..."
)
train
(
g
,
features
,
labels
,
masks
,
model
)
train
(
g
,
features
,
labels
,
masks
,
model
)
...
...
examples/pytorch/gcn/train.py
View file @
8c213ef1
...
@@ -72,6 +72,12 @@ if __name__ == "__main__":
...
@@ -72,6 +72,12 @@ if __name__ == "__main__":
default
=
"cora"
,
default
=
"cora"
,
help
=
"Dataset name ('cora', 'citeseer', 'pubmed')."
,
help
=
"Dataset name ('cora', 'citeseer', 'pubmed')."
,
)
)
parser
.
add_argument
(
"--dt"
,
type
=
str
,
default
=
"float"
,
help
=
"data type(float, bfloat16)"
,
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
print
(
f
"Training with DGL built-in GraphConv module."
)
print
(
f
"Training with DGL built-in GraphConv module."
)
...
@@ -99,6 +105,12 @@ if __name__ == "__main__":
...
@@ -99,6 +105,12 @@ if __name__ == "__main__":
out_size
=
data
.
num_classes
out_size
=
data
.
num_classes
model
=
GCN
(
in_size
,
16
,
out_size
).
to
(
device
)
model
=
GCN
(
in_size
,
16
,
out_size
).
to
(
device
)
# convert model and graph to bfloat16 if needed
if
args
.
dt
==
"bfloat16"
:
g
=
dgl
.
to_bfloat16
(
g
)
features
=
features
.
to
(
dtype
=
torch
.
bfloat16
)
model
=
model
.
to
(
dtype
=
torch
.
bfloat16
)
# model training
# model training
print
(
"Training..."
)
print
(
"Training..."
)
train
(
g
,
features
,
labels
,
masks
,
model
)
train
(
g
,
features
,
labels
,
masks
,
model
)
...
...
examples/pytorch/graphsage/node_classification.py
View file @
8c213ef1
...
@@ -58,6 +58,7 @@ class SAGE(nn.Module):
...
@@ -58,6 +58,7 @@ class SAGE(nn.Module):
y
=
torch
.
empty
(
y
=
torch
.
empty
(
g
.
num_nodes
(),
g
.
num_nodes
(),
self
.
hid_size
if
l
!=
len
(
self
.
layers
)
-
1
else
self
.
out_size
,
self
.
hid_size
if
l
!=
len
(
self
.
layers
)
-
1
else
self
.
out_size
,
dtype
=
feat
.
dtype
,
device
=
buffer_device
,
device
=
buffer_device
,
pin_memory
=
pin_memory
,
pin_memory
=
pin_memory
,
)
)
...
@@ -171,6 +172,12 @@ if __name__ == "__main__":
...
@@ -171,6 +172,12 @@ if __name__ == "__main__":
help
=
"Training mode. 'cpu' for CPU training, 'mixed' for CPU-GPU mixed training, "
help
=
"Training mode. 'cpu' for CPU training, 'mixed' for CPU-GPU mixed training, "
"'puregpu' for pure-GPU training."
,
"'puregpu' for pure-GPU training."
,
)
)
parser
.
add_argument
(
"--dt"
,
type
=
str
,
default
=
"float"
,
help
=
"data type(float, bfloat16)"
,
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
if
not
torch
.
cuda
.
is_available
():
if
not
torch
.
cuda
.
is_available
():
args
.
mode
=
"cpu"
args
.
mode
=
"cpu"
...
@@ -189,6 +196,11 @@ if __name__ == "__main__":
...
@@ -189,6 +196,11 @@ if __name__ == "__main__":
out_size
=
dataset
.
num_classes
out_size
=
dataset
.
num_classes
model
=
SAGE
(
in_size
,
256
,
out_size
).
to
(
device
)
model
=
SAGE
(
in_size
,
256
,
out_size
).
to
(
device
)
# convert model and graph to bfloat16 if needed
if
args
.
dt
==
"bfloat16"
:
g
=
dgl
.
to_bfloat16
(
g
)
model
=
model
.
to
(
dtype
=
torch
.
bfloat16
)
# model training
# model training
print
(
"Training..."
)
print
(
"Training..."
)
train
(
args
,
device
,
g
,
dataset
,
model
,
num_classes
)
train
(
args
,
device
,
g
,
dataset
,
model
,
num_classes
)
...
...
examples/pytorch/graphsage/train_full.py
View file @
8c213ef1
import
argparse
import
argparse
import
dgl
import
dgl.nn
as
dglnn
import
dgl.nn
as
dglnn
import
torch
import
torch
...
@@ -69,6 +70,12 @@ if __name__ == "__main__":
...
@@ -69,6 +70,12 @@ if __name__ == "__main__":
default
=
"cora"
,
default
=
"cora"
,
help
=
"Dataset name ('cora', 'citeseer', 'pubmed')"
,
help
=
"Dataset name ('cora', 'citeseer', 'pubmed')"
,
)
)
parser
.
add_argument
(
"--dt"
,
type
=
str
,
default
=
"float"
,
help
=
"data type(float, bfloat16)"
,
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
print
(
f
"Training with DGL built-in GraphSage module"
)
print
(
f
"Training with DGL built-in GraphSage module"
)
...
@@ -96,6 +103,12 @@ if __name__ == "__main__":
...
@@ -96,6 +103,12 @@ if __name__ == "__main__":
out_size
=
data
.
num_classes
out_size
=
data
.
num_classes
model
=
SAGE
(
in_size
,
16
,
out_size
).
to
(
device
)
model
=
SAGE
(
in_size
,
16
,
out_size
).
to
(
device
)
# convert model and graph to bfloat16 if needed
if
args
.
dt
==
"bfloat16"
:
g
=
dgl
.
to_bfloat16
(
g
)
features
=
features
.
to
(
dtype
=
torch
.
bfloat16
)
model
=
model
.
to
(
dtype
=
torch
.
bfloat16
)
# model training
# model training
print
(
"Training..."
)
print
(
"Training..."
)
train
(
g
,
features
,
labels
,
masks
,
model
)
train
(
g
,
features
,
labels
,
masks
,
model
)
...
...
python/dgl/backend/backend.py
View file @
8c213ef1
...
@@ -21,6 +21,7 @@ def data_type_dict():
...
@@ -21,6 +21,7 @@ def data_type_dict():
"""Returns a dictionary from data type string to the data type.
"""Returns a dictionary from data type string to the data type.
The dictionary should include at least:
The dictionary should include at least:
bfloat16
float16
float16
float32
float32
float64
float64
...
...
python/dgl/backend/pytorch/tensor.py
View file @
8c213ef1
...
@@ -18,6 +18,7 @@ if version.parse(th.__version__) < version.parse("1.12.0"):
...
@@ -18,6 +18,7 @@ if version.parse(th.__version__) < version.parse("1.12.0"):
def
data_type_dict
():
def
data_type_dict
():
return
{
return
{
"bfloat16"
:
th
.
bfloat16
,
"float16"
:
th
.
float16
,
"float16"
:
th
.
float16
,
"float32"
:
th
.
float32
,
"float32"
:
th
.
float32
,
"float64"
:
th
.
float64
,
"float64"
:
th
.
float64
,
...
...
python/dgl/backend/tensorflow/tensor.py
View file @
8c213ef1
...
@@ -30,6 +30,7 @@ def zerocopy_from_dlpack(dlpack_tensor):
...
@@ -30,6 +30,7 @@ def zerocopy_from_dlpack(dlpack_tensor):
def
data_type_dict
():
def
data_type_dict
():
return
{
return
{
"bfloat16"
:
tf
.
bfloat16
,
"float16"
:
tf
.
float16
,
"float16"
:
tf
.
float16
,
"float32"
:
tf
.
float32
,
"float32"
:
tf
.
float32
,
"float64"
:
tf
.
float64
,
"float64"
:
tf
.
float64
,
...
...
python/dgl/frame.py
View file @
8c213ef1
...
@@ -990,18 +990,29 @@ class Frame(MutableMapping):
...
@@ -990,18 +990,29 @@ class Frame(MutableMapping):
F
.
float64
,
F
.
float64
,
F
.
float32
,
F
.
float32
,
F
.
float16
,
F
.
float16
,
F
.
bfloat16
,
],
"'new_type' must be floating-point type: %s"
%
str
(
new_type
)
],
"'new_type' must be floating-point type: %s"
%
str
(
new_type
)
newframe
=
self
.
clone
()
newframe
=
self
.
clone
()
new_columns
=
{}
new_columns
=
{}
for
name
,
column
in
self
.
_columns
.
items
():
for
name
,
column
in
self
.
_columns
.
items
():
dtype
=
column
.
dtype
dtype
=
column
.
dtype
if
dtype
!=
new_type
and
dtype
in
[
F
.
float64
,
F
.
float32
,
F
.
float16
]:
if
dtype
!=
new_type
and
dtype
in
[
F
.
float64
,
F
.
float32
,
F
.
float16
,
F
.
bfloat16
,
]:
new_columns
[
name
]
=
column
.
astype
(
new_type
)
new_columns
[
name
]
=
column
.
astype
(
new_type
)
else
:
else
:
new_columns
[
name
]
=
column
new_columns
[
name
]
=
column
newframe
.
_columns
=
new_columns
newframe
.
_columns
=
new_columns
return
newframe
return
newframe
def
bfloat16
(
self
):
"""Return a new frame with all floating-point columns converted
to bfloat16"""
return
self
.
_astype_float
(
F
.
bfloat16
)
def
half
(
self
):
def
half
(
self
):
"""Return a new frame with all floating-point columns converted
"""Return a new frame with all floating-point columns converted
to half-precision (float16)"""
to half-precision (float16)"""
...
...
python/dgl/transforms/functional.py
View file @
8c213ef1
...
@@ -86,6 +86,7 @@ __all__ = [
...
@@ -86,6 +86,7 @@ __all__ = [
"random_walk_pe"
,
"random_walk_pe"
,
"laplacian_pe"
,
"laplacian_pe"
,
"lap_pe"
,
"lap_pe"
,
"to_bfloat16"
,
"to_half"
,
"to_half"
,
"to_float"
,
"to_float"
,
"to_double"
,
"to_double"
,
...
@@ -3711,6 +3712,24 @@ def laplacian_pe(g, k, padding=False, return_eigval=False):
...
@@ -3711,6 +3712,24 @@ def laplacian_pe(g, k, padding=False, return_eigval=False):
return
lap_pe
(
g
,
k
,
padding
,
return_eigval
)
return
lap_pe
(
g
,
k
,
padding
,
return_eigval
)
def
to_bfloat16
(
g
):
r
"""Cast this graph to use bfloat16 for any
floating-point edge and node feature data.
A shallow copy is returned so that the original graph is not modified.
Feature tensors that are not floating-point will not be modified.
Returns
-------
DGLGraph
Clone of graph with the feature data converted to float16.
"""
ret
=
copy
.
copy
(
g
)
ret
.
_edge_frames
=
[
frame
.
bfloat16
()
for
frame
in
ret
.
_edge_frames
]
ret
.
_node_frames
=
[
frame
.
bfloat16
()
for
frame
in
ret
.
_node_frames
]
return
ret
def
to_half
(
g
):
def
to_half
(
g
):
r
"""Cast this graph to use float16 (half-precision) for any
r
"""Cast this graph to use float16 (half-precision) for any
floating-point edge and node feature data.
floating-point edge and node feature data.
...
...
tests/python/common/test_heterograph.py
View file @
8c213ef1
...
@@ -2443,7 +2443,7 @@ def test_dtype_cast(idtype):
...
@@ -2443,7 +2443,7 @@ def test_dtype_cast(idtype):
def
test_float_cast
():
def
test_float_cast
():
for
t
in
[
F
.
float16
,
F
.
float32
,
F
.
float64
]:
for
t
in
[
F
.
bfloat16
,
F
.
float16
,
F
.
float32
,
F
.
float64
]:
idtype
=
F
.
int32
idtype
=
F
.
int32
g
=
dgl
.
heterograph
(
g
=
dgl
.
heterograph
(
{
{
...
@@ -2469,6 +2469,7 @@ def test_float_cast():
...
@@ -2469,6 +2469,7 @@ def test_float_cast():
(
"c"
,
F
.
float64
),
(
"c"
,
F
.
float64
),
(
"d"
,
F
.
int32
),
(
"d"
,
F
.
int32
),
(
"e"
,
F
.
int64
),
(
"e"
,
F
.
int64
),
(
"f"
,
F
.
bfloat16
),
]
]
for
name
,
type
in
dataNamesTypes
:
for
name
,
type
in
dataNamesTypes
:
g
.
nodes
[
"user"
].
data
[
name
]
=
F
.
copy_to
(
g
.
nodes
[
"user"
].
data
[
name
]
=
F
.
copy_to
(
...
@@ -2487,6 +2488,8 @@ def test_float_cast():
...
@@ -2487,6 +2488,8 @@ def test_float_cast():
F
.
tensor
(
pvalues
,
dtype
=
type
),
ctx
=
F
.
ctx
()
F
.
tensor
(
pvalues
,
dtype
=
type
),
ctx
=
F
.
ctx
()
)
)
if
t
==
F
.
bfloat16
:
g
=
dgl
.
transforms
.
functional
.
to_bfloat16
(
g
)
if
t
==
F
.
float16
:
if
t
==
F
.
float16
:
g
=
dgl
.
transforms
.
functional
.
to_half
(
g
)
g
=
dgl
.
transforms
.
functional
.
to_half
(
g
)
if
t
==
F
.
float32
:
if
t
==
F
.
float32
:
...
@@ -2498,7 +2501,7 @@ def test_float_cast():
...
@@ -2498,7 +2501,7 @@ def test_float_cast():
# integer tensors shouldn't be converted
# integer tensors shouldn't be converted
reqType
=
(
reqType
=
(
t
t
if
(
origType
in
[
F
.
float16
,
F
.
float32
,
F
.
float64
])
if
(
origType
in
[
F
.
bfloat16
,
F
.
float16
,
F
.
float32
,
F
.
float64
])
else
origType
else
origType
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment