Unverified commit 10eb13e2 authored by Tian Zheng, committed by GitHub

[Paddle] Add nn Layers (BF16) (#299)



* Add Linear layer (FP16)
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

- Add BF16 training example
- Add fp8_autocast (only supports non-fp8 for now)
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Remove FP8 stuff
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Simplify Linear layer forward
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Add LayerNorm layer (BF16)
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Add LayerNormLinear layer
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Store weights in BF16
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Add LayerNormMLP layer
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Add BF16 MNIST example
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Remove in-place cast for compatibility with Paddle AMP mechanism
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* README correction
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Add Paddle op as a backend option
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Fix code format
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Fix dtype change between iterations
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Minor fixes
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Move forward function out of base layer
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Use Paddle nvtx bindings
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

---------
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>
parent faefacb8
# Basic MNIST Example (BF16)
```bash
python test_single_gpu_mnist.py
python test_single_gpu_mnist.py --use-te # Linear layers from TransformerEngine
```
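For orientation, `--use-te` swaps the hidden `nn.Linear` layers in the model below for `te.Linear`. A minimal drop-in sketch (assuming a CUDA device with this package installed):

```python
import paddle
import transformer_engine.paddle as te

# te.Linear mirrors paddle.nn.Linear's interface
fc = te.Linear(9216, 128)
x = paddle.uniform(shape=(64, 9216))
y = fc(x)    # shape: [64, 128]
```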
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""MNIST example of Transformer Engine Paddle"""
import argparse
import unittest
import paddle
from paddle import nn
import paddle.nn.functional as F
from paddle.vision.transforms import Normalize
from paddle.io import DataLoader
from paddle.vision.datasets import MNIST
from paddle.metric import Accuracy
import transformer_engine.paddle as te
class Net(nn.Layer):
"""Simple network used to train on MNIST"""
def __init__(self, use_te=False):
super().__init__()
self.conv1 = nn.Conv2D(1, 32, 3, 1)
self.conv2 = nn.Conv2D(32, 64, 3, 1)
self.dropout1 = nn.Dropout(0.25)
self.dropout2 = nn.Dropout(0.5)
if use_te:
self.fc1 = te.Linear(9216, 128)
self.fc2 = te.Linear(128, 16)
else:
self.fc1 = nn.Linear(9216, 128)
self.fc2 = nn.Linear(128, 16)
self.fc3 = nn.Linear(16, 10)
def forward(self, x):
"""FWD"""
x = self.conv1(x)
x = F.relu(x)
x = self.conv2(x)
x = F.relu(x)
x = F.max_pool2d(x, 2)
x = self.dropout1(x)
x = paddle.flatten(x, 1)
x = self.fc1(x)
x = F.relu(x)
x = self.dropout2(x)
x = self.fc2(x)
x = self.fc3(x)
return x
def train(args, model, train_loader, optimizer, epoch):
"""Training function."""
model.train()
for batch_id, (data, labels) in enumerate(train_loader):
with paddle.amp.auto_cast(dtype='bfloat16', level='O2'): # pylint: disable=not-context-manager
outputs = model(data)
loss = F.cross_entropy(outputs, labels)
loss.backward()
optimizer.step()
optimizer.clear_gradients()
if batch_id % args.log_interval == 0:
print(f"Train Epoch: {epoch} "
f"[{batch_id * len(data)}/{len(train_loader.dataset)} "
f"({100. * batch_id / len(train_loader):.0f}%)]\t"
f"Loss: {loss.item():.6f}")
if args.dry_run:
return loss.item()
return loss.item()
def evaluate(model, test_loader, epoch):
"""Testing function."""
model.eval()
metric = Accuracy()
metric.reset()
with paddle.no_grad():
for data, labels in test_loader:
with paddle.amp.auto_cast(dtype='bfloat16', level='O2'): # pylint: disable=not-context-manager
outputs = model(data)
acc = metric.compute(outputs, labels)
metric.update(acc)
print(f"Epoch[{epoch}] - accuracy: {metric.accumulate():.6f}")
return metric.accumulate()
def mnist_parser(args):
"""Parse training settings"""
parser = argparse.ArgumentParser(description="Paddle MNIST Example")
parser.add_argument(
"--batch-size",
type=int,
default=64,
metavar="N",
help="input batch size for training (default: 64)",
)
parser.add_argument(
"--test-batch-size",
type=int,
default=1000,
metavar="N",
help="input batch size for testing (default: 1000)",
)
parser.add_argument(
"--epochs",
type=int,
default=14,
metavar="N",
help="number of epochs to train (default: 14)",
)
parser.add_argument(
"--lr",
type=float,
default=0.001,
metavar="LR",
help="learning rate (default: 0.001)",
)
parser.add_argument(
"--dry-run",
action="store_true",
default=False,
help="quickly check a single pass",
)
parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
parser.add_argument(
"--log-interval",
type=int,
default=10,
metavar="N",
help="how many batches to wait before logging training status",
)
parser.add_argument("--use-te",
action="store_true",
default=False,
help="Use Transformer Engine")
return parser.parse_args(args)
def train_and_evaluate(args):
"""Execute model training and evaluation loop."""
print(args)
paddle.seed(args.seed)
# Load MNIST dataset
transform = Normalize(mean=[127.5], std=[127.5], data_format='CHW')
train_dataset = MNIST(mode='train', transform=transform)
val_dataset = MNIST(mode='test', transform=transform)
# Define data loaders
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=args.test_batch_size)
# Define model and optimizer
model = Net(use_te=args.use_te)
optimizer = paddle.optimizer.Adam(learning_rate=args.lr, parameters=model.parameters())
# Cast model to BF16
model = paddle.amp.decorate(models=model, level='O2', dtype='bfloat16')
for epoch in range(1, args.epochs + 1):
loss = train(args, model, train_loader, optimizer, epoch)
acc = evaluate(model, val_loader, epoch)
return loss, acc
class TestMNIST(unittest.TestCase):
"""MNIST unittests"""
@classmethod
def setUpClass(cls):
"""Run MNIST without Transformer Engine"""
cls.args = mnist_parser(["--epochs", "5"])
@staticmethod
def verify(actual):
"""Check If loss and accuracy match target"""
desired_traing_loss = 0.5
desired_test_accuracy = 0.98
assert actual[0] < desired_traing_loss
assert actual[1] > desired_test_accuracy
@unittest.skipIf(paddle.device.cuda.get_device_capability() < (8, 0),
"BF16 MNIST example requires Ampere+ GPU")
def test_te_bf16(self):
"""Test Transformer Engine with BF16"""
self.args.use_te = True
actual = train_and_evaluate(self.args)
self.verify(actual)
if __name__ == "__main__":
train_and_evaluate(mnist_parser(None))
@@ -6,3 +6,4 @@ set -xe
: ${TE_PATH:=/opt/transformerengine}
pytest -Wignore -v $TE_PATH/tests/paddle
pytest -Wignore -v $TE_PATH/examples/paddle/mnist
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Test TE Paddle Layer-level APIs"""
import pytest
from utils import assert_allclose
import paddle
import transformer_engine.paddle as te
LINEAR_CASES = [(16, 16, 32), (32, 32, 64), (64, 128, 256)]
NORM_CASES = [(16, 32), (256, 1024)]
MLP_CASES = [(32, 32, 32), (64, 256, 512)]
def calc_output_and_grad(layer, x, dy):
"""
Calculate forward and backward pass
"""
inp = paddle.to_tensor(x)
inp.stop_gradient = x.stop_gradient
y = layer(inp)
y.backward(dy)
return y, inp.grad if not inp.stop_gradient else None
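One convention worth noting before the tests: the TE backend stores `Linear` weights as `[out_features, in_features]` (matching its column-major GEMM), while the Paddle backend keeps Paddle's native `[in_features, out_features]`. Hence the transposed `copy_` calls throughout; a minimal sketch:

```python
import paddle
import transformer_engine.paddle as te

layer_te = te.Linear(16, 32)                      # weight shape: [32, 16]
layer_pd = te.Linear(16, 32, backend='paddle')    # weight shape: [16, 32]
layer_pd.weight.copy_(layer_te.weight.T, True)    # transpose to align the layouts
layer_pd.bias.copy_(layer_te.bias, True)
```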
@pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0),
reason="BF16 Linear requires Ampere+ GPU")
@pytest.mark.parametrize('bs,in_features,out_features', LINEAR_CASES)
def test_linear_bf16(bs, in_features, out_features):
"""
Test BF16 Linear
"""
rtol = 1e-2
atol = 1e-2
input_tensor = paddle.uniform(shape=(bs, in_features), dtype='bfloat16')
input_tensor.stop_gradient = False
grad_out = paddle.uniform(shape=(bs, out_features), dtype='bfloat16')
paddle.set_default_dtype("bfloat16")
layer_te = te.Linear(in_features, out_features)
layer_pd = te.Linear(in_features, out_features, backend='paddle')
layer_pd.weight.copy_(layer_te.weight.T, True)
layer_pd.bias.copy_(layer_te.bias, True)
out_ref, grad_input_ref = calc_output_and_grad(layer_pd, input_tensor, grad_out)
out, grad_input = calc_output_and_grad(layer_te, input_tensor, grad_out)
assert_allclose(out, out_ref, rtol=rtol, atol=atol)
assert_allclose(grad_input, grad_input_ref, rtol=rtol, atol=atol)
assert_allclose(layer_te.weight.grad, layer_pd.weight.grad.T, rtol=rtol, atol=atol)
assert_allclose(layer_te.bias.grad, layer_pd.bias.grad, rtol=rtol, atol=atol)
@pytest.mark.parametrize('bs,hidden_size', NORM_CASES)
def test_layernorm_bf16(bs, hidden_size):
"""
Test BF16 LayerNorm
"""
eps = 1e-3
rtol = 1e-2
atol = 1e-2
x = paddle.uniform(shape=(bs, hidden_size), dtype='bfloat16')
x.stop_gradient = False
grad_out = paddle.uniform(shape=(bs, hidden_size), dtype='bfloat16')
paddle.set_default_dtype("bfloat16")
layer_te = te.LayerNorm(hidden_size=hidden_size, eps=eps)
layer_pd = te.LayerNorm(hidden_size=hidden_size, eps=eps, backend='paddle')
layer_pd.weight.copy_(layer_te.weight, True)
layer_pd.bias.copy_(layer_te.bias, True)
out_ref, grad_input_ref = calc_output_and_grad(layer_pd, x, grad_out)
out, grad_input = calc_output_and_grad(layer_te, x, grad_out)
assert_allclose(out, out_ref, rtol=rtol, atol=atol)
assert_allclose(grad_input, grad_input_ref, rtol=rtol, atol=atol)
assert_allclose(layer_te.weight.grad, layer_pd.weight.grad, rtol=rtol, atol=atol)
assert_allclose(layer_te.bias.grad, layer_pd.bias.grad, rtol=rtol, atol=atol)
@pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0),
reason="BF16 Linear requires Ampere+ GPU")
@pytest.mark.parametrize('bs,in_features,out_features', LINEAR_CASES)
def test_layernorm_linear_bf16(bs, in_features, out_features):
"""
Test BF16 LayerNormLinear Layer
"""
paddle.set_default_dtype("bfloat16")
rtol = 1e-2
atol = 1e-2
input_tensor = paddle.uniform(shape=(bs, in_features), dtype='bfloat16')
input_tensor.stop_gradient = False
grad_out = paddle.uniform(shape=(bs, out_features), dtype='bfloat16')
eps = 1e-3
layer_te = te.LayerNormLinear(
in_features=in_features,
out_features=out_features,
eps=eps,
)
layer_pd = te.LayerNormLinear(in_features=in_features,
out_features=out_features,
eps=eps,
backend='paddle')
layer_pd.ln_weight.copy_(layer_te.ln_weight, True)
layer_pd.ln_bias.copy_(layer_te.ln_bias, True)
layer_pd.weight.copy_(layer_te.weight.T, True)
layer_pd.bias.copy_(layer_te.bias, True)
out_ref, grad_input_ref = calc_output_and_grad(layer_pd, input_tensor, grad_out)
out, grad_input = calc_output_and_grad(layer_te, input_tensor, grad_out)
assert_allclose(out, out_ref, rtol=rtol, atol=atol)
assert_allclose(grad_input, grad_input_ref, rtol=rtol, atol=atol)
assert_allclose(layer_te.weight.grad, layer_pd.weight.grad.T, rtol=rtol, atol=atol)
assert_allclose(layer_te.bias.grad, layer_pd.bias.grad, rtol=rtol, atol=atol)
assert_allclose(layer_te.ln_weight.grad, layer_pd.ln_weight.grad, rtol=rtol, atol=atol)
assert_allclose(layer_te.ln_bias.grad, layer_pd.ln_bias.grad, rtol=rtol, atol=atol)
@pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0),
reason="BF16 Linear requires Ampere+ GPU")
@pytest.mark.parametrize('bs,hidden_size,ffn_hidden_size', MLP_CASES)
def test_layernorm_mlp_bf16(bs, hidden_size, ffn_hidden_size):
"""
Test BF16 LayerNormMLP Layer
"""
paddle.set_default_dtype("bfloat16")
rtol = 5e-2
atol = 5e-2
input_tensor = paddle.uniform(shape=(bs, hidden_size), dtype='bfloat16')
input_tensor.stop_gradient = False
grad_out = paddle.uniform(shape=(bs, hidden_size), dtype='bfloat16')
eps = 1e-3
layer_te = te.LayerNormMLP(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
eps=eps,
)
layer_pd = te.LayerNormMLP(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
eps=eps,
backend='paddle',
)
layer_pd.ln_weight.copy_(layer_te.ln_weight, True)
layer_pd.ln_bias.copy_(layer_te.ln_bias, True)
layer_pd.fc1_weight.copy_(layer_te.fc1_weight.T, True)
layer_pd.fc1_bias.copy_(layer_te.fc1_bias, True)
layer_pd.fc2_weight.copy_(layer_te.fc2_weight.T, True)
layer_pd.fc2_bias.copy_(layer_te.fc2_bias, True)
out_ref, grad_input_ref = calc_output_and_grad(layer_pd, input_tensor, grad_out)
out, grad_input = calc_output_and_grad(layer_te, input_tensor, grad_out)
assert_allclose(out, out_ref, rtol=rtol, atol=atol)
assert_allclose(grad_input, grad_input_ref, rtol=rtol, atol=atol)
assert_allclose(layer_te.fc1_weight.grad, layer_pd.fc1_weight.grad.T, rtol=rtol, atol=atol)
assert_allclose(layer_te.fc1_bias.grad, layer_pd.fc1_bias.grad, rtol=rtol, atol=atol)
assert_allclose(layer_te.fc2_weight.grad, layer_pd.fc2_weight.grad.T, rtol=rtol, atol=atol)
assert_allclose(layer_te.fc2_bias.grad, layer_pd.fc2_bias.grad, rtol=rtol, atol=atol)
assert_allclose(layer_te.ln_weight.grad, layer_pd.ln_weight.grad, rtol=rtol, atol=atol)
assert_allclose(layer_te.ln_bias.grad, layer_pd.ln_bias.grad, rtol=rtol, atol=atol)
@@ -3,8 +3,10 @@
# See LICENSE for license information.
"""Utils for testing"""
import numpy as np
import paddle
import transformer_engine    # pylint: disable=unused-import
import transformer_engine_paddle as tex    # pylint: disable=wrong-import-order
@@ -3,4 +3,4 @@
# See LICENSE for license information.
"""Transformer Engine bindings for Paddle"""
from .cpp_extensions import gemm, fp8_gemm, cast_to_fp8, cast_from_fp8
from .layer import Linear, LayerNorm, LayerNormLinear, LayerNormMLP
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Layer level Paddle APIs"""
from .layernorm import LayerNorm
from .layernorm_linear import LayerNormLinear
from .layernorm_mlp import LayerNormMLP
from .linear import Linear
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Base modules and utilities for TransformerEngine Paddle API"""
from abc import ABC, abstractmethod
from contextlib import contextmanager
import paddle
from paddle.fluid import core
from paddle.fluid.framework import _dygraph_tracer
from ..profile import nvtx_range
_cublas_workspace = None
def get_cublas_workspace_size_bytes() -> int:
"""Return 32 MiB if using Hopper, 4 MiB for all other architectures."""
if paddle.device.cuda.get_device_capability()[0] >= 9:
return 33_554_432
return 4_194_304
def get_workspace() -> paddle.Tensor:
"""Returns workspace for cublas."""
global _cublas_workspace
if _cublas_workspace is None:
_cublas_workspace = paddle.empty(
[get_cublas_workspace_size_bytes()],
dtype='uint8',
)
return _cublas_workspace
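As a usage note, the workspace is allocated lazily on first use and cached at module scope, so repeated GEMM calls reuse one buffer. A small sanity sketch (hypothetical check, not part of the module):

```python
ws_a = get_workspace()
ws_b = get_workspace()
assert ws_a is ws_b                  # cached module-level tensor, allocated once
assert ws_a.dtype == paddle.uint8    # raw byte buffer handed to cuBLAS
# 33_554_432 bytes on SM 9.0+ (Hopper), 4_194_304 bytes otherwise
```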
class TransformerEngineBaseLayer(paddle.nn.Layer, ABC):
"""Base TE Layer."""
def __init__(self) -> None:
super().__init__()
assert 'gpu' in paddle.device.get_device(), "TransformerEngine needs CUDA."
def set_activation_dtype(self, inp: paddle.Tensor) -> None:
"""Get activation data type for AMP."""
# Native AMP (`paddle.amp.auto_cast`) gets highest priority
tracer = _dygraph_tracer()
if tracer and tracer._amp_level != core.AmpLevel.O0:
if tracer._amp_dtype == 'float32':
self.activation_dtype = paddle.float32
elif tracer._amp_dtype == 'bfloat16':
self.activation_dtype = paddle.bfloat16
elif tracer._amp_dtype == 'float16':
self.activation_dtype = paddle.float16
else:
raise RuntimeError(f"AMP format {tracer._amp_dtype} is not supported.")
return
# The parameter checks below only need to run once, so skip them afterwards.
# We assume the user does not change input dtypes across iterations.
if hasattr(self, "activation_dtype") and self.activation_dtype == inp.dtype:
return
dtype = inp.dtype
for name, param in self.named_parameters():
if param is not None:
assert dtype == param.dtype, (
"Data types for parameters must match when outside of autocasted region. "
f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}")
self.activation_dtype = dtype
@contextmanager
def prepare_forward(
self,
inp: paddle.Tensor,
) -> None:
"""
Checks and prep for FWD.
"""
self.set_activation_dtype(inp)
with nvtx_range(self.__class__.__name__ + " forward"):
yield inp
@abstractmethod
def forward(self):
"""Needs override."""
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Linear API"""
import os
from typing import Union, Tuple
import paddle
import paddle.nn.functional as F
from paddle.nn.initializer import Constant
from ..constants import TE_DType
from ..cpp_extensions import layernorm_fwd, layernorm_bwd
__all__ = ["LayerNorm"]
class _LayerNorm(paddle.autograd.PyLayer):
"""TE Non-FP8 LayerNorm"""
@staticmethod
def forward(
ctx,
inp: paddle.Tensor,
ln_weight: paddle.Tensor,
ln_bias: paddle.Tensor,
eps: float,
fwd_ln_sm_margin: int,
bwd_ln_sm_margin: int,
zero_centered_gamma: bool,
) -> paddle.Tensor:
# Make sure input dimensions are compatible
in_features = ln_weight.numel()
assert inp.shape[-1] == in_features, "LayerNorm not possible"
inputmat = inp.reshape((-1, in_features))
ln_out, mu, rsigma = layernorm_fwd(inputmat, ln_weight, ln_bias, eps, TE_DType[inp.dtype],
fwd_ln_sm_margin, zero_centered_gamma)
ctx.save_for_backward(inputmat, ln_weight, mu, rsigma)
ctx.inp_shape = inp.shape
ctx.bwd_ln_sm_margin = bwd_ln_sm_margin
ctx.zero_centered_gamma = zero_centered_gamma
ctx.requires_dx = not inp.stop_gradient
ctx.requires_dw = not ln_weight.stop_gradient
ctx.requires_dbias = not ln_bias.stop_gradient
return ln_out.reshape(inp.shape)
@staticmethod
def backward(ctx, grad_output: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None], ...]:
inputmat, ln_weight, mu, rsigma = ctx.saved_tensor()
d_ln_out = grad_output.reshape(inputmat.shape)
dxmat, dgamma, dbeta = layernorm_bwd(d_ln_out, inputmat, mu, rsigma, ln_weight,
ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma)
return (
dxmat.reshape(ctx.inp_shape) if ctx.requires_dx else None,
dgamma if ctx.requires_dw else None,
dbeta if ctx.requires_dbias else None,
)
class LayerNorm(paddle.nn.Layer):
r"""
Applies Layer Normalization over a mini-batch of inputs as described in
the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__
"""
def __init__(
self,
hidden_size: int,
eps: float = 1e-5,
weight_attr: Union[paddle.ParamAttr, None] = None,
bias_attr: Union[paddle.ParamAttr, None, bool] = None,
zero_centered_gamma: bool = False,
backend: str = 'transformer_engine',
) -> None:
super().__init__()
self.eps = eps
self.zero_centered_gamma = zero_centered_gamma
self.backend = backend
self._dtype = self._helper.get_default_dtype()
self._weight_attr = weight_attr
if not self._weight_attr:
self._weight_attr = paddle.ParamAttr(initializer=Constant(
value=0.0 if self.zero_centered_gamma else 1.0))
self._bias_attr = bias_attr
if self._bias_attr is False:
self._bias_attr = paddle.ParamAttr(initializer=Constant(value=0.0), trainable=False)
self.weight = self.create_parameter(
shape=[hidden_size],
attr=self._weight_attr,
dtype=self._dtype,
is_bias=False,
)
self.bias = self.create_parameter(
shape=[hidden_size],
attr=self._bias_attr,
dtype=self._dtype,
is_bias=True,
)
# This many SMs are subtracted from the total SM count when calling forward
# and backward LayerNorm C APIs. These envvars can be used to prevent the LN
# kernels from using all SMs in the device. This is useful for cases such as
# communication overlap with LN.
self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
def _te_forward(self, inp: paddle.Tensor) -> paddle.Tensor:
"""LayerNorm FWD"""
return _LayerNorm.apply(inp, self.weight, self.bias, self.eps, self.fwd_ln_sm_margin,
self.bwd_ln_sm_margin, self.zero_centered_gamma)
def _pd_forward(
self,
inp: paddle.Tensor,
) -> paddle.Tensor:
"""Calls Paddle OP"""
if self.zero_centered_gamma:
raise NotImplementedError(
"Paddle backend does not support LayerNorm with zero-centered scale.")
return F.layer_norm(x=inp,
normalized_shape=inp.shape[1:],
weight=self.weight,
bias=self.bias,
epsilon=self.eps)
def forward(self, *args, **kwargs):
"""forward"""
if self.backend == 'transformer_engine':
return self._te_forward(*args, **kwargs)
if self.backend == 'paddle':
return self._pd_forward(*args, **kwargs)
raise AttributeError(f"Backend {self.backend} is not supported.")
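For reference, the two backends are intended to be numerically interchangeable within BF16 tolerance, mirroring `test_layernorm_bf16` earlier in this diff. A minimal sketch (Ampere+ GPU assumed):

```python
import paddle
import transformer_engine.paddle as te

paddle.set_default_dtype('bfloat16')
ln_te = te.LayerNorm(hidden_size=1024, eps=1e-3)
ln_pd = te.LayerNorm(hidden_size=1024, eps=1e-3, backend='paddle')
ln_pd.weight.copy_(ln_te.weight, True)
ln_pd.bias.copy_(ln_te.bias, True)

x = paddle.uniform(shape=(16, 1024), dtype='bfloat16')
y_te, y_pd = ln_te(x), ln_pd(x)    # should agree to ~1e-2 rtol/atol
```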
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""LayerNormLinear API"""
import os
from typing import Union, Tuple
import paddle
import paddle.nn.functional as F
from paddle.nn.initializer import Constant
from ..cpp_extensions import (
gemm,
layernorm_fwd,
layernorm_bwd,
)
from .base import get_workspace, TransformerEngineBaseLayer
from ..constants import TE_DType
from ..utils import cast_if_needed
__all__ = ["LayerNormLinear"]
class _LayerNormLinear(paddle.autograd.PyLayer):
"""TE implementation of non-FP8 LayerNormLinear"""
@staticmethod
def forward(
ctx,
inp: paddle.Tensor,
ln_weight: paddle.Tensor,
ln_bias: paddle.Tensor,
weight: paddle.Tensor,
bias: Union[paddle.Tensor, None],
use_bias: bool,
eps: float,
activation_dtype: paddle.dtype,
return_layernorm_output: bool,
fwd_ln_sm_margin: int,
bwd_ln_sm_margin: int,
zero_centered_gamma: bool,
) -> Union[Tuple[paddle.Tensor, ...], paddle.Tensor]:
# Make sure input dimensions are compatible
in_features = ln_weight.numel()
assert inp.shape[-1] == in_features, "GEMM not possible"
inputmat = inp.reshape((-1, in_features))
ln_out, mu, rsigma = layernorm_fwd(inputmat, ln_weight, ln_bias, eps,
TE_DType[activation_dtype], fwd_ln_sm_margin,
zero_centered_gamma)
out, _, _ = gemm(
weight,
ln_out,
activation_dtype,
get_workspace(),
bias=bias,
use_bias=use_bias,
)
ctx.save_for_backward(
inputmat,
ln_weight,
mu,
rsigma,
weight,
ln_out,
)
ctx.activation_dtype = activation_dtype
ctx.use_bias = use_bias
ctx.inp_shape = inp.shape
ctx.return_layernorm_output = return_layernorm_output
ctx.bwd_ln_sm_margin = bwd_ln_sm_margin
ctx.zero_centered_gamma = zero_centered_gamma
ctx.requires_dgrad = not inp.stop_gradient
# [*, in_features] -> [*, out_features] except first dimension changes for SP
out = out.reshape((-1, *inp.shape[1:-1], out.shape[-1]))
if return_layernorm_output:
return out, ln_out.reshape(inp.shape)
return out
@staticmethod
def backward(
ctx, *grad_outputs: Tuple[paddle.Tensor,
...]) -> Tuple[Union[paddle.Tensor, None], ...]:
(
inputmat,
ln_weight,
mu,
rsigma,
weight,
ln_out,
) = ctx.saved_tensor()
grad_output = grad_outputs[0]
wgrad, grad_bias = None, None    # stay None when the wgrad GEMM is skipped
# Dgrad
dgrad, _, _ = gemm(
weight,
grad_output,
ctx.activation_dtype,
get_workspace(),
layout="NN",
grad=True,
)
# Wgrad
if not weight.stop_gradient:
wgrad, grad_bias, _ = gemm(
ln_out,
grad_output,
ctx.activation_dtype,
get_workspace(),
layout="NT",
grad=True,
use_bias=ctx.use_bias,
)
# LayerNorm gradient
d_ln_out = dgrad.reshape(inputmat.shape)
# Residual gradient
if ctx.return_layernorm_output:
d_ln_out = d_ln_out + grad_outputs[1].reshape(d_ln_out.shape)
dxmat, dgamma, dbeta = layernorm_bwd(d_ln_out, inputmat, mu, rsigma, ln_weight,
ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma)
if not ctx.use_bias:
return (
dxmat.reshape(ctx.inp_shape) if ctx.requires_dgrad else None,
dgamma,
dbeta,
wgrad if not weight.stop_gradient else None,
)
return (
dxmat.reshape(ctx.inp_shape) if ctx.requires_dgrad else None,
dgamma,
dbeta,
wgrad if not weight.stop_gradient else None,
grad_bias,
)
class LayerNormLinear(TransformerEngineBaseLayer):
r"""
Applies layer normalization followed by a linear transformation to the incoming data.
"""
def __init__(
self,
in_features: int,
out_features: int,
eps: float = 1e-5,
weight_attr: Union[paddle.ParamAttr, None] = None,
bias_attr: Union[paddle.ParamAttr, None, bool] = None,
return_layernorm_output: bool = False,
zero_centered_gamma: bool = False,
backend: str = 'transformer_engine',
) -> None:
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.eps = eps
self.return_layernorm_output = return_layernorm_output
self.zero_centered_gamma = zero_centered_gamma
self.backend = backend
self._weight_attr = weight_attr
self._bias_attr = bias_attr
self._dtype = self._helper.get_default_dtype()
# LayerNorm weights
self.ln_weight = self.create_parameter(
shape=[in_features],
attr=paddle.ParamAttr(initializer=Constant(
value=0.0 if self.zero_centered_gamma else 1.0)),
dtype=self._dtype,
is_bias=False,
)
self.ln_bias = self.create_parameter(
shape=[in_features],
attr=paddle.ParamAttr(initializer=Constant(value=0.0)),
dtype=self._dtype,
is_bias=True,
)
# Linear weights
self.weight = self.create_parameter(
shape=[out_features, in_features]
if self.backend == 'transformer_engine' else [in_features, out_features],
attr=self._weight_attr,
dtype=self._dtype,
is_bias=False,
)
self.has_bias = self._bias_attr is not False
if self.has_bias:
self.bias = self.create_parameter(
shape=[out_features],
attr=self._bias_attr if self._bias_attr is not None else paddle.ParamAttr(
initializer=Constant(value=0.0)),
dtype=self._dtype,
is_bias=True,
)
else:
self.bias = None
# This many SMs are subtracted from the total SM count when calling forward
# and backward LayerNorm C APIs. These envvars can be used to prevent the LN
# kernels from using all SMs in the device. This is useful for cases such as
# communication overlap with LN.
self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
def _te_forward(
self,
inp: paddle.Tensor,
) -> Union[paddle.Tensor, Tuple[paddle.Tensor, ...]]:
"""
Apply layer normalization to the input followed by a linear transformation.
"""
with self.prepare_forward(inp) as inp:
out = _LayerNormLinear.apply(
cast_if_needed(inp, self.activation_dtype),
cast_if_needed(self.ln_weight, self.activation_dtype),
cast_if_needed(self.ln_bias, self.activation_dtype),
cast_if_needed(self.weight, self.activation_dtype),
cast_if_needed(self.bias, self.activation_dtype),
self.has_bias,
self.eps,
self.activation_dtype,
self.return_layernorm_output,
self.fwd_ln_sm_margin,
self.bwd_ln_sm_margin,
self.zero_centered_gamma,
)
if self.return_layernorm_output:
out, ln_out = out
return out, ln_out
return out
def _pd_forward(
self,
inp: paddle.Tensor,
) -> paddle.Tensor:
"""Calls Paddle OP"""
if self.zero_centered_gamma:
raise NotImplementedError(
"Paddle backend does not support LayerNorm with zero-centered scale.")
ln_out = F.layer_norm(x=inp,
normalized_shape=inp.shape[1:],
weight=self.ln_weight,
bias=self.ln_bias,
epsilon=self.eps)
out = F.linear(ln_out, self.weight, self.bias)
if self.return_layernorm_output:
return out, ln_out
return out
def forward(self, *args, **kwargs):
"""forward"""
if self.backend == 'transformer_engine':
return self._te_forward(*args, **kwargs)
if self.backend == 'paddle':
return self._pd_forward(*args, **kwargs)
raise AttributeError(f"Backend {self.backend} is not supported.")
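A short usage sketch (shapes are illustrative): with `return_layernorm_output=True`, the layer also returns the normalized input, e.g. for a residual or parallel branch:

```python
import paddle
import transformer_engine.paddle as te

layer = te.LayerNormLinear(in_features=256, out_features=512,
                           return_layernorm_output=True)
x = paddle.uniform(shape=(32, 256))
out, ln_out = layer(x)    # out: [32, 512]; ln_out: [32, 256]
```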
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""LayerNormMLP API"""
import os
from typing import Union, Tuple
import paddle
import paddle.nn.functional as F
from paddle.nn.initializer import Constant
from ..cpp_extensions import (
gemm,
layernorm_fwd,
layernorm_bwd,
)
from .base import get_workspace, TransformerEngineBaseLayer
from ..constants import TE_DType
from ..utils import cast_if_needed, get_paddle_act_func
__all__ = ["LayerNormMLP"]
class _LayerNormMLP(paddle.autograd.PyLayer):
"""TE implementation of non-FP8 LayerNormMLP"""
@staticmethod
def forward(
ctx,
inp: paddle.Tensor,
ln_weight: paddle.Tensor,
ln_bias: paddle.Tensor,
fc1_weight: paddle.Tensor,
fc1_bias: Union[paddle.Tensor, None],
use_fc1_bias: bool,
fc2_weight: paddle.Tensor,
fc2_bias: Union[paddle.Tensor, None],
use_fc2_bias: bool,
eps: float,
activation_dtype: paddle.dtype,
return_layernorm_output: bool,
fwd_ln_sm_margin: int,
bwd_ln_sm_margin: int,
zero_centered_gamma: bool,
activation: str,
) -> Union[Tuple[paddle.Tensor, ...], paddle.Tensor]:
# Make sure input dimensions are compatible
in_features = ln_weight.numel()
assert inp.shape[-1] == in_features, "GEMM not possible"
inputmat = inp.reshape((-1, in_features))
# Only 'gelu' is supported for now
assert activation == 'gelu'
# LN FWD
ln_out, mu, rsigma = layernorm_fwd(inputmat, ln_weight, ln_bias, eps,
TE_DType[activation_dtype], fwd_ln_sm_margin,
zero_centered_gamma)
# FC1 + GeLU
gelu_out, _, fc1_out = gemm(
fc1_weight,
ln_out,
activation_dtype,
get_workspace(),
bias=fc1_bias,
use_bias=use_fc1_bias,
gelu=(activation == 'gelu'),
)
# FC2
fc2_out, _, _ = gemm(
fc2_weight,
gelu_out,
activation_dtype,
get_workspace(),
bias=fc2_bias,
use_bias=use_fc2_bias,
)
ctx.save_for_backward(
inputmat,
ln_weight,
mu,
rsigma,
ln_out,
fc1_out,
gelu_out,
fc1_weight,
fc2_weight,
)
ctx.activation_dtype = activation_dtype
ctx.activation = activation
ctx.use_fc1_bias = use_fc1_bias
ctx.use_fc2_bias = use_fc2_bias
ctx.inp_shape = inp.shape
ctx.return_layernorm_output = return_layernorm_output
ctx.bwd_ln_sm_margin = bwd_ln_sm_margin
ctx.zero_centered_gamma = zero_centered_gamma
ctx.requires_dgrad = not inp.stop_gradient
# [*, in_features] -> [*, out_features] except first dimension changes for SP
fc2_out = fc2_out.reshape((-1, *inp.shape[1:-1], fc2_out.shape[-1]))
if return_layernorm_output:
return fc2_out, ln_out.reshape(inp.shape)
return fc2_out
@staticmethod
def backward(
ctx, *grad_outputs: Tuple[paddle.Tensor,
...]) -> Tuple[Union[paddle.Tensor, None], ...]:
(
inputmat,
ln_weight,
mu,
rsigma,
ln_out,
fc1_out,
gelu_out,
fc1_weight,
fc2_weight,
) = ctx.saved_tensor()
# Gradient w.r.t. the FC2 output
grad_output = grad_outputs[0]
fc1_wgrad = fc2_wgrad = fc1_bias_grad = fc2_bias_grad = None    # stay None when skipped
# FC2 Dgrad + dGELU
dgelu, _, _ = gemm(
fc2_weight,
grad_output,
ctx.activation_dtype,
get_workspace(),
layout="NN",
gelu=(ctx.activation == 'gelu'),
gelu_input=fc1_out,
grad=True,
)
# FC2 Wgrad
if not fc2_weight.stop_gradient:
fc2_wgrad, fc2_bias_grad, _ = gemm(
gelu_out,
grad_output,
ctx.activation_dtype,
get_workspace(),
layout="NT",
grad=True,
use_bias=ctx.use_fc2_bias,
)
# For non-fp8 execution, FC1 bias gradient is fused with FC1 wgrad GEMM
# and will not be calculated in case wgrad is not required.
if fc1_weight.stop_gradient:
fc1_bias_grad = dgelu.sum(axis=0)
# FC1 DGRAD
fc1_dgrad, _, _ = gemm(
fc1_weight,
dgelu,
ctx.activation_dtype,
get_workspace(),
layout="NN",
grad=True,
)
# FC1 Wgrad
if not fc1_weight.stop_gradient:
fc1_wgrad, fc1_bias_grad, _ = gemm(
ln_out,
dgelu,
ctx.activation_dtype,
get_workspace(),
layout="NT",
grad=True,
use_bias=ctx.use_fc1_bias,
)
# LayerNorm gradient
d_ln_out = fc1_dgrad.reshape(inputmat.shape)
# Residual gradient
if ctx.return_layernorm_output:
d_ln_out = d_ln_out + grad_outputs[1].reshape(d_ln_out.shape)
dxmat, dgamma, dbeta = layernorm_bwd(d_ln_out, inputmat, mu, rsigma, ln_weight,
ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma)
fc1_bias_grad_out = (fc1_bias_grad,) if ctx.use_fc1_bias else ()
fc2_bias_grad_out = (fc2_bias_grad,) if ctx.use_fc2_bias else ()
return (
dxmat.reshape(ctx.inp_shape) if ctx.requires_dgrad else None,
dgamma,
dbeta,
fc1_wgrad if not fc1_weight.stop_gradient else None,
*fc1_bias_grad_out,
fc2_wgrad if not fc2_weight.stop_gradient else None,
*fc2_bias_grad_out,
)
class LayerNormMLP(TransformerEngineBaseLayer):
r"""
Applies layer normalization followed by an MLP block (linear, activation, linear) to the incoming data.
"""
def __init__(
self,
hidden_size: int,
ffn_hidden_size: int,
eps: float = 1e-5,
weight_attr: Union[paddle.ParamAttr, None] = None,
bias_attr: Union[paddle.ParamAttr, None, bool] = None,
activation: str = "gelu",
return_layernorm_output: bool = False,
zero_centered_gamma: bool = False,
backend: str = 'transformer_engine',
) -> None:
super().__init__()
self.hidden_size = hidden_size
self.ffn_hidden_size = ffn_hidden_size
self.eps = eps
self.activation = activation
self.return_layernorm_output = return_layernorm_output
self.zero_centered_gamma = zero_centered_gamma
self.backend = backend
self._weight_attr = weight_attr
self._bias_attr = bias_attr
self._dtype = self._helper.get_default_dtype()
# LayerNorm weights
self.ln_weight = self.create_parameter(
shape=[self.hidden_size],
attr=paddle.ParamAttr(initializer=Constant(
value=0.0 if self.zero_centered_gamma else 1.0)),
dtype=self._dtype,
is_bias=False,
)
self.ln_bias = self.create_parameter(
shape=[self.hidden_size],
attr=paddle.ParamAttr(initializer=Constant(value=0.0)),
dtype=self._dtype,
is_bias=True,
)
# FC1 weights
self.fc1_weight = self.create_parameter(
shape=[self.ffn_hidden_size, self.hidden_size]
if self.backend == 'transformer_engine' else [self.hidden_size, self.ffn_hidden_size],
attr=self._weight_attr,
dtype=self._dtype,
is_bias=False,
)
self.has_bias = self._bias_attr is not False
if self._bias_attr is None:
self._bias_attr = paddle.ParamAttr(initializer=Constant(value=0.0))
if self.has_bias:
self.fc1_bias = self.create_parameter(
shape=[self.ffn_hidden_size],
attr=self._bias_attr,
dtype=self._dtype,
is_bias=True,
)
else:
self.fc1_bias = None
# FC2 weights
self.fc2_weight = self.create_parameter(
shape=[self.hidden_size, self.ffn_hidden_size]
if self.backend == 'transformer_engine' else [self.ffn_hidden_size, self.hidden_size],
attr=self._weight_attr,
dtype=self._dtype,
is_bias=False,
)
if self.has_bias:
self.fc2_bias = self.create_parameter(
shape=[self.hidden_size],
attr=self._bias_attr,
dtype=self._dtype,
is_bias=True,
)
else:
self.fc2_bias = None
# This many SMs are subtracted from the total SM count when calling forward
# and backward LayerNorm C APIs. These envvars can be used to prevent the LN
# kernels from using all SMs in the device. This is useful for cases such as
# communication overlap with LN.
self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
def _te_forward(
self,
inp: paddle.Tensor,
) -> Union[paddle.Tensor, Tuple[paddle.Tensor, ...]]:
"""
Apply layer normalization to the input followed by a linear transformation.
"""
with self.prepare_forward(inp) as inp:
out = _LayerNormMLP.apply(
cast_if_needed(inp, self.activation_dtype),
cast_if_needed(self.ln_weight, self.activation_dtype),
cast_if_needed(self.ln_bias, self.activation_dtype),
cast_if_needed(self.fc1_weight, self.activation_dtype),
cast_if_needed(self.fc1_bias, self.activation_dtype),
self.has_bias,
cast_if_needed(self.fc2_weight, self.activation_dtype),
cast_if_needed(self.fc2_bias, self.activation_dtype),
self.has_bias,
self.eps,
self.activation_dtype,
self.return_layernorm_output,
self.fwd_ln_sm_margin,
self.bwd_ln_sm_margin,
self.zero_centered_gamma,
self.activation,
)
if self.return_layernorm_output:
out, ln_out = out
return out, ln_out
return out
def _pd_forward(
self,
inp: paddle.Tensor,
) -> paddle.Tensor:
"""Calls Paddle OP"""
if self.zero_centered_gamma:
raise NotImplementedError(
"Paddle backend does not support LayerNorm with zero-centered scale.")
ln_out = F.layer_norm(x=inp,
normalized_shape=inp.shape[1:],
weight=self.ln_weight,
bias=self.ln_bias,
epsilon=self.eps)
fc1_out = F.linear(ln_out, self.fc1_weight, self.fc1_bias)
act_func = get_paddle_act_func(self.activation)
act_out = act_func(fc1_out)
out = F.linear(act_out, self.fc2_weight, self.fc2_bias)
if self.return_layernorm_output:
return out, ln_out
return out
def forward(self, *args, **kwargs):
"""forward"""
if self.backend == 'transformer_engine':
return self._te_forward(*args, **kwargs)
if self.backend == 'paddle':
return self._pd_forward(*args, **kwargs)
raise AttributeError(f"Backend {self.backend} is not supported.")
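A minimal usage sketch of the fused block (sizes are illustrative):

```python
import paddle
import transformer_engine.paddle as te

mlp = te.LayerNormMLP(hidden_size=1024, ffn_hidden_size=4096, activation='gelu')
x = paddle.uniform(shape=(8, 1024))
y = mlp(x)    # LayerNorm -> FC1 -> GeLU -> FC2, output shape [8, 1024]
```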
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Linear API"""
from typing import Union, Tuple
import paddle
import paddle.nn.functional as F
from paddle.nn.initializer import Constant
from .base import TransformerEngineBaseLayer, get_workspace
from ..cpp_extensions import gemm
from ..utils import cast_if_needed
__all__ = ["Linear"]
class _Linear(paddle.autograd.PyLayer):
"""TE implementation of non-FP8 Linear"""
@staticmethod
def forward(
ctx,
weight: paddle.Tensor,
inp: paddle.Tensor,
bias: paddle.Tensor,
use_bias: bool,
activation_dtype: paddle.dtype,
) -> paddle.Tensor:
# Make sure input dimensions are compatible
in_features = weight.shape[-1]
assert inp.shape[-1] == in_features, "GEMM not possible"
inputmat = inp.reshape((-1, in_features))
out, _, _ = gemm(
weight,
inputmat,
activation_dtype,
get_workspace(),
bias=bias,
use_bias=use_bias,
)
ctx.save_for_backward(
inputmat,
weight,
)
ctx.activation_dtype = activation_dtype
ctx.use_bias = use_bias
ctx.inp_shape = inp.shape
ctx.requires_dgrad = not inp.stop_gradient
return out.reshape((-1, *inp.shape[1:-1], out.shape[-1]))
@staticmethod
def backward(ctx, grad_output: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None], ...]:
inputmat, weight = ctx.saved_tensor()
wgrad, grad_bias = None, None    # stay None when the wgrad GEMM is skipped
if ctx.requires_dgrad:
dgrad, _, _ = gemm(
weight,
grad_output,
ctx.activation_dtype,
get_workspace(),
layout="NN",
grad=True,
)
if not weight.stop_gradient:
wgrad, grad_bias, _ = gemm(
inputmat,
grad_output,
ctx.activation_dtype,
get_workspace(),
layout="NT",
grad=True,
use_bias=ctx.use_bias,
)
if not ctx.use_bias:
return (
wgrad if not weight.stop_gradient else None,
dgrad.reshape(ctx.inp_shape) if ctx.requires_dgrad else None,
)
return (
wgrad if not weight.stop_gradient else None,
dgrad.reshape(ctx.inp_shape) if ctx.requires_dgrad else None,
grad_bias,
)
class Linear(TransformerEngineBaseLayer):
"""
Applies a linear transformation to the incoming data :math:`y = xA^T + b`
"""
def __init__(
self,
in_features: int,
out_features: int,
weight_attr: Union[paddle.ParamAttr, None] = None,
bias_attr: Union[paddle.ParamAttr, None, bool] = None,
backend: str = 'transformer_engine',
) -> None:
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.backend = backend
self._weight_attr = weight_attr
self._bias_attr = bias_attr
self._dtype = self._helper.get_default_dtype()
# The TE linear weight is stored in column-major order
self.weight = self.create_parameter(
shape=[out_features, in_features]
if self.backend == 'transformer_engine' else [in_features, out_features],
attr=self._weight_attr,
dtype=self._dtype,
is_bias=False,
)
self.has_bias = self._bias_attr is not False
if self.has_bias:
self.bias = self.create_parameter(
shape=[out_features],
attr=self._bias_attr if self._bias_attr is not None else paddle.ParamAttr(
initializer=Constant(value=0.0)),
dtype=self._dtype,
is_bias=True,
)
else:
self.bias = None
def _te_forward(
self,
inp: paddle.Tensor,
) -> paddle.Tensor:
"""
Apply the linear transformation to the input.
"""
with self.prepare_forward(inp) as inp:
out = _Linear.apply(
cast_if_needed(self.weight, self.activation_dtype),
cast_if_needed(inp, self.activation_dtype),
cast_if_needed(self.bias, self.activation_dtype),
self.has_bias,
self.activation_dtype,
)
return out
def _pd_forward(
self,
inp: paddle.Tensor,
) -> paddle.Tensor:
"""Calls Paddle OP"""
return F.linear(inp, self.weight, self.bias)
def forward(self, *args, **kwargs):
"""forward"""
if self.backend == 'transformer_engine':
return self._te_forward(*args, **kwargs)
if self.backend == 'paddle':
return self._pd_forward(*args, **kwargs)
raise AttributeError(f"Backend {self.backend} is not supported.")
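The `backend` argument only changes which forward path `forward()` dispatches to; a quick sketch:

```python
import paddle
import transformer_engine.paddle as te

x = paddle.uniform(shape=(4, 128))
y_te = te.Linear(128, 256)(x)                     # _te_forward: custom GEMM kernel
y_pd = te.Linear(128, 256, backend='paddle')(x)   # _pd_forward: paddle.nn.functional.linear
```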
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Utils for profiling"""
from contextlib import contextmanager
from paddle.fluid import core
@contextmanager
def nvtx_range(msg):
"""Context to insert NVTX"""
core.nvprof_nvtx_push(msg)
yield
core.nvprof_nvtx_pop()
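A hedged usage sketch (the import path follows this PR's module layout): wrapping a region in `nvtx_range` makes it appear as an NVTX range in Nsight Systems timelines:

```python
import paddle
from transformer_engine.paddle.profile import nvtx_range

a = paddle.randn([64, 64])
b = paddle.randn([64, 64])
with nvtx_range("matmul"):
    c = paddle.matmul(a, b)    # shows up as a "matmul" NVTX range when profiled
```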
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Utility functions for Transformer Engine modules"""
from typing import Union
import paddle
import paddle.nn.functional as F
def cast_if_needed(tensor: Union[paddle.Tensor, None],
dtype: paddle.dtype) -> Union[paddle.Tensor, None]:
"""Cast tensor to dtype"""
return tensor if tensor is None or tensor.dtype == dtype else paddle.cast(tensor, dtype)
def get_paddle_act_func(activation):
"""Get paddle activation function"""
funcs = {
'gelu': F.gelu,
'relu': F.relu,
}
if activation not in funcs:
raise "Activation type " + activation + " is not supported."
return funcs[activation]
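A quick sketch of the helpers above (import path assumed from this PR's layout):

```python
import paddle
from transformer_engine.paddle.utils import cast_if_needed, get_paddle_act_func

x = paddle.ones([2, 2], dtype='float32')
assert cast_if_needed(x, paddle.float32) is x          # same dtype: returned unchanged
y = cast_if_needed(x, paddle.bfloat16)                 # different dtype: new casted tensor
assert cast_if_needed(None, paddle.bfloat16) is None   # None passes through
act = get_paddle_act_func('gelu')                      # -> paddle.nn.functional.gelu
```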