"git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "866c70dace72c8d73e2011b8881fcf0a6eea4be7"
Unverified commit 8c213ef1, authored by Ilia Taraban, committed by GitHub

[Feature] Enable bfloat16 convert functions in Python API (#5760)

parent b6f5ba9a
@@ -4,8 +4,8 @@ Chapter 8: Mixed Precision Training
===================================

DGL is compatible with the `PyTorch Automatic Mixed Precision (AMP) package
<https://pytorch.org/docs/stable/amp.html>`_
for mixed precision training, thus saving both training time and GPU/CPU memory
consumption. This feature requires DGL 0.9+ (and DGL 1.1+ for bfloat16 on CPU).

Message-Passing with Half Precision
-----------------------------------

@@ -58,18 +58,19 @@ DGL relies on PyTorch's AMP package for mixed precision training,
and the user experience is exactly
the same as `PyTorch's <https://pytorch.org/docs/stable/notes/amp_examples.html>`_.

By wrapping the forward pass with ``torch.amp.autocast()``, PyTorch automatically
selects the appropriate datatype for each op and tensor. Half-precision tensors are
memory efficient, and most operators on them are faster because they leverage GPU
tensor cores and specialized CPU instruction sets.
.. code::

    import torch.nn.functional as F
    from torch.amp import autocast

    def forward(device_type, g, feat, label, mask, model, amp_dtype):
        amp_enabled = amp_dtype in (torch.float16, torch.bfloat16)
        with autocast(device_type, enabled=amp_enabled, dtype=amp_dtype):
            logit = model(g, feat)
            loss = F.cross_entropy(logit[mask], label[mask])
            return loss
@@ -104,7 +105,7 @@ Pay attention to the differences in the code when AMP is activated or not.

    from dgl.nn import GATConv
    from dgl.transforms import AddSelfLoop

    amp_dtype = torch.bfloat16  # or torch.float16

    class GAT(nn.Module):
        def __init__(self,

@@ -130,7 +131,8 @@ Pay attention to the differences in the code when AMP is activated or not.

    # Data loading
    transform = AddSelfLoop()
    data = RedditDataset(transform)
    device_type = 'cuda'  # or 'cpu'
    dev = torch.device(device_type)
    g = data[0]
    g = g.int().to(dev)

@@ -151,7 +153,7 @@ Pay attention to the differences in the code when AMP is activated or not.

    for epoch in range(100):
        optimizer.zero_grad()
        loss = forward(device_type, g, feat, label, train_mask, model, amp_dtype)
        if amp_dtype == torch.float16:
            # Backprop w/ gradient scaling
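For reference, the gradient-scaled backward pass that this ``float16`` branch leads into typically looks like the sketch below. This is an illustration rather than the guide's exact code, and it assumes a ``torch.cuda.amp.GradScaler`` instance created before the loop.

.. code::

    scaler = torch.cuda.amp.GradScaler()

    for epoch in range(100):
        optimizer.zero_grad()
        loss = forward(device_type, g, feat, label, train_mask, model, amp_dtype)
        if amp_dtype == torch.float16:
            # Scale the loss to avoid fp16 gradient underflow, then let the
            # scaler unscale the gradients before the optimizer update.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            # bfloat16 (and full precision) training needs no gradient scaling.
            loss.backward()
            optimizer.step()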
@@ -169,5 +171,87 @@ If we change the number of heads to ``[2, 2, 2]``, training without fp16
triggers a GPU OOM (out-of-memory) error, while training with fp16 consumes
15.7GB of GPU memory.

BFloat16 CPU example
-----------------------------------

DGL supports training in the bfloat16 data type on the CPU.
This data type does not require any special CPU features and can improve the
performance of memory-bound models.
Starting with 4th Generation Intel Xeon processors, which support the `AMX
<https://www.intel.com/content/www/us/en/products/docs/accelerator-engines/advanced-matrix-extensions/overview.html>`_ instruction set, bfloat16 can significantly improve training and inference performance without major code changes.

Here is an example of simple GCN training in bfloat16:
.. code::

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    import dgl
    from dgl.data import CiteseerGraphDataset
    from dgl.nn import GraphConv
    from dgl.transforms import AddSelfLoop

    class GCN(nn.Module):
        def __init__(self, in_size, hid_size, out_size):
            super().__init__()
            self.layers = nn.ModuleList()
            # two-layer GCN
            self.layers.append(
                GraphConv(in_size, hid_size, activation=F.relu)
            )
            self.layers.append(GraphConv(hid_size, out_size))
            self.dropout = nn.Dropout(0.5)

        def forward(self, g, features):
            h = features
            for i, layer in enumerate(self.layers):
                if i != 0:
                    h = self.dropout(h)
                h = layer(g, h)
            return h

    # Data loading
    transform = AddSelfLoop()
    data = CiteseerGraphDataset(transform=transform)
    g = data[0]
    g = g.int()
    train_mask = g.ndata['train_mask']
    feat = g.ndata['feat']
    label = g.ndata['label']
    in_size = feat.shape[1]
    hid_size = 16
    out_size = data.num_classes
    model = GCN(in_size, hid_size, out_size)

    # Convert model and graph to bfloat16
    g = dgl.to_bfloat16(g)
    feat = feat.to(dtype=torch.bfloat16)
    model = model.to(dtype=torch.bfloat16)

    model.train()

    # Create optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)
    loss_fcn = nn.CrossEntropyLoss()

    for epoch in range(100):
        optimizer.zero_grad()
        logits = model(g, feat)
        loss = loss_fcn(logits[train_mask], label[train_mask])
        loss.backward()
        optimizer.step()
        print('Epoch {} | Loss {}'.format(epoch, loss.item()))
The only difference from regular training is converting the model and graph to bfloat16 before training/inference:

.. code::

    g = dgl.to_bfloat16(g)
    feat = feat.to(dtype=torch.bfloat16)
    model = model.to(dtype=torch.bfloat16)
DGL is still improving its half-precision support and the compute kernels'
performance is far from optimal; please stay tuned for future updates.
import argparse

import dgl
import dgl.nn as dglnn
import torch

@@ -88,6 +89,12 @@ if __name__ == "__main__":
        default="cora",
        help="Dataset name ('cora', 'citeseer', 'pubmed').",
    )
    parser.add_argument(
        "--dt",
        type=str,
        default="float",
        help="data type(float, bfloat16)",
    )
    args = parser.parse_args()
    print(f"Training with DGL built-in GATConv module.")

@@ -115,6 +122,12 @@ if __name__ == "__main__":
    out_size = data.num_classes
    model = GAT(in_size, 8, out_size, heads=[8, 1]).to(device)

    # convert model and graph to bfloat16 if needed
    if args.dt == "bfloat16":
        g = dgl.to_bfloat16(g)
        features = features.to(dtype=torch.bfloat16)
        model = model.to(dtype=torch.bfloat16)

    # model training
    print("Training...")
    train(g, features, labels, masks, model)
......
@@ -72,6 +72,12 @@ if __name__ == "__main__":
        default="cora",
        help="Dataset name ('cora', 'citeseer', 'pubmed').",
    )
    parser.add_argument(
        "--dt",
        type=str,
        default="float",
        help="data type(float, bfloat16)",
    )
    args = parser.parse_args()
    print(f"Training with DGL built-in GraphConv module.")

@@ -99,6 +105,12 @@ if __name__ == "__main__":
    out_size = data.num_classes
    model = GCN(in_size, 16, out_size).to(device)

    # convert model and graph to bfloat16 if needed
    if args.dt == "bfloat16":
        g = dgl.to_bfloat16(g)
        features = features.to(dtype=torch.bfloat16)
        model = model.to(dtype=torch.bfloat16)

    # model training
    print("Training...")
    train(g, features, labels, masks, model)
......
@@ -58,6 +58,7 @@ class SAGE(nn.Module):
            y = torch.empty(
                g.num_nodes(),
                self.hid_size if l != len(self.layers) - 1 else self.out_size,
                dtype=feat.dtype,
                device=buffer_device,
                pin_memory=pin_memory,
            )
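The added ``dtype=feat.dtype`` argument presumably keeps the layer-wise inference buffer in the same dtype as the input features; with the previous float32 default, bfloat16 layer outputs copied into the buffer would be upcast and then fed back into bfloat16 layers on the next hop. A minimal illustration of that mismatch in plain PyTorch (not part of the change):

import torch

layer = torch.nn.Linear(8, 8).to(dtype=torch.bfloat16)
h = layer(torch.randn(4, 8, dtype=torch.bfloat16))  # bfloat16 output

buf = torch.empty(4, 8)   # defaults to float32
buf[:] = h                # silently upcast to float32
try:
    layer(buf)            # float32 input vs. bfloat16 weights
except RuntimeError as err:
    print("dtype mismatch:", err)

buf = torch.empty(4, 8, dtype=h.dtype)  # mirrors dtype=feat.dtype
buf[:] = h
layer(buf)                # stays in bfloat16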
@@ -171,6 +172,12 @@ if __name__ == "__main__":
        help="Training mode. 'cpu' for CPU training, 'mixed' for CPU-GPU mixed training, "
        "'puregpu' for pure-GPU training.",
    )
    parser.add_argument(
        "--dt",
        type=str,
        default="float",
        help="data type(float, bfloat16)",
    )
    args = parser.parse_args()
    if not torch.cuda.is_available():
        args.mode = "cpu"

@@ -189,6 +196,11 @@ if __name__ == "__main__":
    out_size = dataset.num_classes
    model = SAGE(in_size, 256, out_size).to(device)

    # convert model and graph to bfloat16 if needed
    if args.dt == "bfloat16":
        g = dgl.to_bfloat16(g)
        model = model.to(dtype=torch.bfloat16)

    # model training
    print("Training...")
    train(args, device, g, dataset, model, num_classes)
......
import argparse

import dgl
import dgl.nn as dglnn
import torch

@@ -69,6 +70,12 @@ if __name__ == "__main__":
        default="cora",
        help="Dataset name ('cora', 'citeseer', 'pubmed')",
    )
    parser.add_argument(
        "--dt",
        type=str,
        default="float",
        help="data type(float, bfloat16)",
    )
    args = parser.parse_args()
    print(f"Training with DGL built-in GraphSage module")

@@ -96,6 +103,12 @@ if __name__ == "__main__":
    out_size = data.num_classes
    model = SAGE(in_size, 16, out_size).to(device)

    # convert model and graph to bfloat16 if needed
    if args.dt == "bfloat16":
        g = dgl.to_bfloat16(g)
        features = features.to(dtype=torch.bfloat16)
        model = model.to(dtype=torch.bfloat16)

    # model training
    print("Training...")
    train(g, features, labels, masks, model)
......
@@ -21,6 +21,7 @@ def data_type_dict():
    """Returns a dictionary from data type string to the data type.

    The dictionary should include at least:

    bfloat16
    float16
    float32
    float64
......
@@ -18,6 +18,7 @@ if version.parse(th.__version__) < version.parse("1.12.0"):
def data_type_dict():
    return {
        "bfloat16": th.bfloat16,
        "float16": th.float16,
        "float32": th.float32,
        "float64": th.float64,
......
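With this entry in place, the PyTorch backend's dtype registry resolves the name "bfloat16", which DGL mirrors as an attribute on ``dgl.backend`` (the frame and test changes below refer to it as ``F.bfloat16``). A quick sanity check, assuming the PyTorch backend is active:

import torch

import dgl.backend as F

# The registry entry is exposed as a backend dtype alias, like float16/float32.
assert F.bfloat16 is torch.bfloat16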
@@ -30,6 +30,7 @@ def zerocopy_from_dlpack(dlpack_tensor):
def data_type_dict():
    return {
        "bfloat16": tf.bfloat16,
        "float16": tf.float16,
        "float32": tf.float32,
        "float64": tf.float64,
......
@@ -990,18 +990,29 @@ class Frame(MutableMapping):
            F.float64,
            F.float32,
            F.float16,
            F.bfloat16,
        ], "'new_type' must be floating-point type: %s" % str(new_type)
        newframe = self.clone()
        new_columns = {}
        for name, column in self._columns.items():
            dtype = column.dtype
            if dtype != new_type and dtype in [
                F.float64,
                F.float32,
                F.float16,
                F.bfloat16,
            ]:
                new_columns[name] = column.astype(new_type)
            else:
                new_columns[name] = column
        newframe._columns = new_columns
        return newframe

    def bfloat16(self):
        """Return a new frame with all floating-point columns converted
        to bfloat16"""
        return self._astype_float(F.bfloat16)

    def half(self):
        """Return a new frame with all floating-point columns converted
        to half-precision (float16)"""
......
@@ -86,6 +86,7 @@ __all__ = [
    "random_walk_pe",
    "laplacian_pe",
    "lap_pe",
    "to_bfloat16",
    "to_half",
    "to_float",
    "to_double",

@@ -3711,6 +3712,24 @@ def laplacian_pe(g, k, padding=False, return_eigval=False):
    return lap_pe(g, k, padding, return_eigval)


def to_bfloat16(g):
    r"""Cast this graph to use bfloat16 for any
    floating-point edge and node feature data.

    A shallow copy is returned so that the original graph is not modified.
    Feature tensors that are not floating-point will not be modified.

    Returns
    -------
    DGLGraph
        Clone of graph with the feature data converted to bfloat16.
    """
    ret = copy.copy(g)
    ret._edge_frames = [frame.bfloat16() for frame in ret._edge_frames]
    ret._node_frames = [frame.bfloat16() for frame in ret._node_frames]
    return ret


def to_half(g):
    r"""Cast this graph to use float16 (half-precision) for any
    floating-point edge and node feature data.
......
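For context, a minimal usage sketch of the new ``to_bfloat16`` transform (the toy graph is illustrative and not part of the change); as the docstring notes, non-floating-point features are left untouched:

import torch

import dgl

g = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])))
g.ndata["feat"] = torch.randn(3, 4)         # float32 features
g.ndata["label"] = torch.tensor([0, 1, 0])  # int64 labels

g = dgl.to_bfloat16(g)
print(g.ndata["feat"].dtype)   # torch.bfloat16
print(g.ndata["label"].dtype)  # torch.int64 (not converted)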
@@ -2443,7 +2443,7 @@ def test_dtype_cast(idtype):
def test_float_cast():
    for t in [F.bfloat16, F.float16, F.float32, F.float64]:
        idtype = F.int32
        g = dgl.heterograph(
            {

@@ -2469,6 +2469,7 @@ def test_float_cast():
            ("c", F.float64),
            ("d", F.int32),
            ("e", F.int64),
            ("f", F.bfloat16),
        ]
        for name, type in dataNamesTypes:
            g.nodes["user"].data[name] = F.copy_to(

@@ -2487,6 +2488,8 @@ def test_float_cast():
                F.tensor(pvalues, dtype=type), ctx=F.ctx()
            )

        if t == F.bfloat16:
            g = dgl.transforms.functional.to_bfloat16(g)
        if t == F.float16:
            g = dgl.transforms.functional.to_half(g)
        if t == F.float32:

@@ -2498,7 +2501,7 @@ def test_float_cast():
            # integer tensors shouldn't be converted
            reqType = (
                t
                if (origType in [F.bfloat16, F.float16, F.float32, F.float64])
                else origType
            )
......