Unverified Commit 496b8fdd authored by Jeng Bai-Cheng, committed by GitHub

[JAX] add multiprocessing example and improve debugging message (#198)



* add mp example
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* refactor
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* update doc-string
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* better FP8 checker
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* update readme
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* replace te.* with te.flax.* to remove deprecation warnings
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* remove unused os.environ
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* remove unused code
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* fix typo
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* Update examples/jax/encoder/README.md
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Jeng Bai-Cheng <jeng1220@users.noreply.github.com>
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* Update examples/jax/encoder/README.md
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Jeng Bai-Cheng <jeng1220@users.noreply.github.com>
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* Update examples/jax/encoder/README.md
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Jeng Bai-Cheng <jeng1220@users.noreply.github.com>
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* Update examples/jax/encoder/README.md
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Jeng Bai-Cheng <jeng1220@users.noreply.github.com>
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* Update examples/jax/encoder/test_multiprocessing_encoder.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Jeng Bai-Cheng <jeng1220@users.noreply.github.com>
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* remove cuda-python
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* adjust readme
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* Update examples/jax/encoder/README.md
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* fix cpp lint

fix the issue of "Could not find a newline character at the end of the file."
Signed-off-by: Jeng Bai-Cheng <jeng1220@users.noreply.github.com>
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* fix AssertionError: 1 GPU per process
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* replace tfds with datasets

The Flax application crashes if it uses TensorFlow Datasets (tfds) in the NVIDIA JAX container.
tfds is very useful for downloading well-known datasets (e.g., MNIST, GLUE) and is commonly used by the TF/JAX community.
However, it appears to be incompatible with NVIDIA TensorFlow in the NVIDIA JAX container and somehow affects JAX.
It triggers random errors at JAX initialization depending on the versions involved, making CI unstable.
Thus, this commit replaces tfds with Hugging Face `datasets` to download the needed datasets.
See "nvbugs 4039266" for more details.
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* fix input sharding

Unlike SPMD mode, in multiprocessing mode, the input tensor must be sharded manually.
Using DP=4, TP=2 as an example, the device mesh looks like:

mesh.device_ids = [[0, 1],
                   [2, 3],
                   [4, 5],
                   [6, 7]]

Assume that the process ID is mapped to the GPU ID.
Process 0 and process 1 are grouped for model parallelism,
process 2 and process 3 are grouped together too, and so on.

Process 0 and process 1 need to share the same micro-batch in the training step,
while processes 0, 2, 4, and 6 each get a different micro-batch.

Thus, `shard_array_wrapper` partitions the inputs into 4 parts (and sets up the
arguments needed for jax.make_array_from_single_device_arrays).
Process 0 and process 1 take the first quarter,
process 2 and process 3 take the second quarter, and so on.
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* refactor UT for multiprocess example

Use Python `multiprocessing` to test the multiprocessing example
if the system has multiple GPUs, with 1 GPU per process.

Because `jax.distributed.initialize` must be called before any other JAX or Flax API,
GPU info cannot be queried by calling jax.local_devices() in TestEncoder.
Thus, `unittest_query_gpu()` forks another process to query the number of GPUs and
FP8 capability.
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* remove unused arg `--num-gpu`

JAX doesn't have an API to set the number of GPUs used in SPMD mode.
The only way for now is to use `CUDA_VISIBLE_DEVICES`.
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* fix typo
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* fix ut
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* simplify the mask setting
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* increase batch-size for multigpu example

A batch size of 64 is too small to be partitioned across 8xH100.
With a batch size of 64, each of the 8 GPUs gets only 8 rows, and with max_seq_len 32 and
hidden size 256 the flattened input is 8192 wide, so the GEMM shape is 256x8192x8 per GPU.
The 8 is too small for the FP8 GEMM kernel, and
cuBLASLt will throw "Failed to query heuristics".
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* fix MNIST download error

Downloading MNIST via `huggingface datasets` requires Pillow.
Otherwise, it throws `An error occurred while generating the
dataset`.
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

---------
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>
Signed-off-by: Jeng Bai-Cheng <jeng1220@users.noreply.github.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 441fa968
...@@ -35,7 +35,14 @@ python test_single_gpu_encoder.py --use-fp8 ...@@ -35,7 +35,14 @@ python test_single_gpu_encoder.py --use-fp8
6. Fill in `params_pspec` and `encoder.init` to pjit to get a compiled function, `pjit_encoder_init `, and use it to initialize the model, so JAX now can know how to do the sharding. 6. Fill in `params_pspec` and `encoder.init` to pjit to get a compiled function, `pjit_encoder_init `, and use it to initialize the model, so JAX now can know how to do the sharding.
7. The `train_step` and `eval_step` also needs to be compiled by pjit. Thus, every input and output argument has to be set up `PartitionSpec` if the argument contains a tensor. For instance, the `input_pspec` is `PartitionSpec('data', None)` because the input shape is (batch size, sequence length). Then, the rest of the workflow is similar to the previous example. 7. The `train_step` and `eval_step` also need to be compiled by pjit. Thus, every input and output argument has to be set up `PartitionSpec` if the argument contains a tensor. For instance, the `input_pspec` is `PartitionSpec('data', None)` because the input shape is (batch size, sequence length). Then, the rest of the workflow is similar to the previous example.
8. Use `CUDA_VISIBLE_DEVICES` to control the number of GPUs used. For example, if the system has 8 GPUs but only 4 of them should be used:
```sh
export CUDA_VISIBLE_DEVICES=0,1,2,3
python test_multigpu_encoder.py
```
Please refer to [CUDA Environment Variables](https://docs.nvidia.com/cuda/cuda-c-programming-guide/#cuda-environment-variables) for more details.
### Run ### ### Run ###
...@@ -60,6 +67,10 @@ python test_multigpu_encoder.py --use-fp8 ...@@ -60,6 +67,10 @@ python test_multigpu_encoder.py --use-fp8
import os import os
os.environ['XLA_FLAGS'] = "--xla_dump_hlo_as_proto --xla_dump_hlo_as_text --xla_dump_hlo_as_html --xla_dump_to=<path to store XLA HLO>" os.environ['XLA_FLAGS'] = "--xla_dump_hlo_as_proto --xla_dump_hlo_as_text --xla_dump_hlo_as_html --xla_dump_to=<path to store XLA HLO>"
``` ```
5. If the model parallelism example is run in a container, it is recommended to add `--ipc=host` to the launch arguments. Otherwise, it might trigger UCX errors.
```sh
docker run --gpus=all --ipc=host ...
```
### Run ### ### Run ###
...@@ -67,3 +78,62 @@ python test_multigpu_encoder.py --use-fp8 ...@@ -67,3 +78,62 @@ python test_multigpu_encoder.py --use-fp8
python test_model_parallel_encoder.py python test_model_parallel_encoder.py
python test_model_parallel_encoder.py --use-fp8 python test_model_parallel_encoder.py --use-fp8
``` ```
## Multiple Processes with Model Parallelism ##
1. This example extends the previous model parallelism example, but uses multiprocessing instead of single-program multiple-data (SPMD). It uses 1 GPU per process.
2. The benefit of multiprocessing is the ability to set up hardware affinity for GPUs, such as NUMA binding, which may help improve performance and stability. Please refer to [Best Practices When Benchmarking CUDA Applications](https://www.nvidia.com/en-us/on-demand/session/gtcsiliconvalley2019-s9956/) for more details.
3. A quick way to check the system topology is to use `nvidia-smi`, for example:
```sh
$ nvidia-smi topo -mp
CPU Affinity NUMA Affinity
GPU0 48-63,176-191 3
GPU1 48-63,176-191 3
GPU2 16-31,144-159 1
GPU3 16-31,144-159 1
GPU4 112-127,240-255 7
GPU5 112-127,240-255 7
GPU6 80-95,208-223 5
GPU7 80-95,208-223 5
```
4. It is recommended to set the environment variable `CUDA_DEVICE_ORDER` to `PCI_BUS_ID` before running the example with the affinity setting, so that the device order is consistent between CUDA and `nvidia-smi`. Please refer to [CUDA Environment Variables](https://docs.nvidia.com/cuda/cuda-c-programming-guide/#cuda-environment-variables) for more details.
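One way to set it in the shell before launching the processes:
```sh
export CUDA_DEVICE_ORDER=PCI_BUS_ID
```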
5. `jax.distributed.initialize` must be called before any other JAX or Flax API; otherwise, `jax.local_devices` will be incorrect. `jax.distributed.shutdown` should be the last API call.
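A minimal sketch of the expected call order, using the same arguments as `test_multiprocessing_encoder.py` (the values shown are illustrative):
```python
import jax

# Must be the first JAX/Flax call in each process.
jax.distributed.initialize(coordinator_address="127.0.0.1:1234",
                           num_processes=8,
                           process_id=0,         # unique per process
                           local_device_ids=0)   # 1 GPU per process

# ... build the model, run training and evaluation ...

# Last JAX call before the process exits.
jax.distributed.shutdown()
```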
6. Unlike SPMD, the input tensor must be sharded manually and wrapped with `jax.make_array_from_single_device_arrays`; otherwise, the sharding will be incorrect. Using DP=4, TP=2 as an example, the device mesh looks like:
```python
mesh.device_ids = [[0, 1],
[2, 3],
[4, 5],
[6, 7]]
```
Assume that the process ID is mapped to the GPU ID. Process 0 and process 1 are grouped for model parallelism, process 2 and process 3 are grouped together too, and so on. Thus, process 0 and process 1 need to share the same micro-batch in each training step, while processes 0, 2, 4, and 6 each receive a different micro-batch.
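A minimal sketch of this manual sharding (illustrative only; the example wraps this logic in `shard_array_wrapper`, and the snippet assumes all 8 processes have already called `jax.distributed.initialize`):
```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import mesh_utils

# DP=4, TP=2 mesh over the 8 global devices.
devices = mesh_utils.create_device_mesh((4, 2))
mesh = jax.sharding.Mesh(devices, ('data', 'model'))
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('data', None))

global_batch, seq_len = 64, 32
local_batch = global_batch // 4                # rows per data-parallel group
tp_group_id = jax.local_devices()[0].id // 2   # processes {0,1}, {2,3}, ... share data

# Every process holds the full batch; each keeps only its TP group's quarter.
host_batch = np.zeros((global_batch, seq_len), dtype=np.int32)
local_shard = jnp.asarray(host_batch[tp_group_id * local_batch:(tp_group_id + 1) * local_batch])

# Assemble the logical global array from this process's single-device shard.
global_input = jax.make_array_from_single_device_arrays(
    (global_batch, seq_len), sharding, [local_shard])
```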
### Run ###
If the system has 8 GPUs, the basic commands are:
```bash
python test_multiprocessing_encoder.py --num-process 8 --process-id 0 &
python test_multiprocessing_encoder.py --num-process 8 --process-id 1 &
python test_multiprocessing_encoder.py --num-process 8 --process-id 2 &
python test_multiprocessing_encoder.py --num-process 8 --process-id 3 &
python test_multiprocessing_encoder.py --num-process 8 --process-id 4 &
python test_multiprocessing_encoder.py --num-process 8 --process-id 5 &
python test_multiprocessing_encoder.py --num-process 8 --process-id 6 &
python test_multiprocessing_encoder.py --num-process 8 --process-id 7 &
```
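Equivalently, a shell loop can launch all eight processes (a convenience sketch, not part of the example):
```bash
for i in $(seq 0 7); do
  python test_multiprocessing_encoder.py --num-process 8 --process-id $i &
done
wait
```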
The correct settings for hardware affinity are system dependent. Taking the above system topology as an example, the commands could be:
```bash
numactl --cpunodebind=48 --membind=3 python test_multiprocessing_encoder.py --num-process 8 --process-id 0 &
numactl --cpunodebind=49 --membind=3 python test_multiprocessing_encoder.py --num-process 8 --process-id 1 &
numactl --cpunodebind=16 --membind=1 python test_multiprocessing_encoder.py --num-process 8 --process-id 2 &
numactl --cpunodebind=17 --membind=1 python test_multiprocessing_encoder.py --num-process 8 --process-id 3 &
numactl --cpunodebind=112 --membind=7 python test_multiprocessing_encoder.py --num-process 8 --process-id 4 &
numactl --cpunodebind=113 --membind=7 python test_multiprocessing_encoder.py --num-process 8 --process-id 5 &
numactl --cpunodebind=80 --membind=5 python test_multiprocessing_encoder.py --num-process 8 --process-id 6 &
numactl --cpunodebind=81 --membind=5 python test_multiprocessing_encoder.py --num-process 8 --process-id 7 &
```
\ No newline at end of file
datasets
flax flax
nltk nltk
optax optax
tensorflow-datasets
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# #
# See LICENSE for license information. # See LICENSE for license information.
""" Encoder training on multi-GPU with tesnor parallelism""" """Encoder training on multi-GPU with tesnor parallelism"""
import argparse import argparse
import unittest import unittest
from functools import partial from functools import partial
...@@ -11,10 +11,10 @@ import jax.numpy as jnp ...@@ -11,10 +11,10 @@ import jax.numpy as jnp
import nltk import nltk
import numpy as np import numpy as np
import optax import optax
import tensorflow_datasets as tfds from datasets import load_dataset
from cuda import cudart
from flax import linen as nn from flax import linen as nn
from flax.core.frozen_dict import FrozenDict from flax.core.frozen_dict import FrozenDict
from flax.linen import partitioning as nn_partitioning
from flax.training import train_state from flax.training import train_state
from jax.experimental import mesh_utils from jax.experimental import mesh_utils
from jax.experimental.pjit import pjit from jax.experimental.pjit import pjit
...@@ -31,26 +31,6 @@ DROPOUT_KEY = 'dropout' ...@@ -31,26 +31,6 @@ DROPOUT_KEY = 'dropout'
INPUT_KEY = 'input_rng' INPUT_KEY = 'input_rng'
def check_num_gpu(desired_num_gpu):
"""Check if the number of GPUs are correct."""
actual_num_gpu = len(jax.local_devices())
assert actual_num_gpu == desired_num_gpu, f"Number of GPUs is mismatch. " \
f"{desired_num_gpu} GPUs are assigned, but the actual number of GPUs is {actual_num_gpu}"
def gpu_has_fp8():
"""Check if the GPU has FP8."""
cudaSuccess = cudart.cudaError_t.cudaSuccess
ret, gpu_id = cudart.cudaGetDevice()
assert ret == cudaSuccess
flag = cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMajor
_, major = cudart.cudaDeviceGetAttribute(flag, gpu_id)
flag = cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMinor
_, minor = cudart.cudaDeviceGetAttribute(flag, gpu_id)
sm_arch = major * 10 + minor
return sm_arch >= 89
class Net(nn.Module): class Net(nn.Module):
"""NLP Encoder""" """NLP Encoder"""
num_embed: int num_embed: int
...@@ -66,7 +46,7 @@ class Net(nn.Module): ...@@ -66,7 +46,7 @@ class Net(nn.Module):
hidden_dropout=0.1, hidden_dropout=0.1,
attention_dropout=0.1, attention_dropout=0.1,
dropout_rng_name=DROPOUT_KEY, dropout_rng_name=DROPOUT_KEY,
layer_type=te.TransformerLayerType.ENCODER, layer_type=te.flax.TransformerLayerType.ENCODER,
enable_relative_embedding=False, enable_relative_embedding=False,
dtype=jnp.bfloat16) dtype=jnp.bfloat16)
x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout) x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
...@@ -177,12 +157,11 @@ def data_preprocess(dataset, vocab, word_id, max_seq_len): ...@@ -177,12 +157,11 @@ def data_preprocess(dataset, vocab, word_id, max_seq_len):
nltk.download('punkt') nltk.download('punkt')
dataset_size = len(dataset['sentence']) dataset_size = len(dataset['sentence'])
output = np.zeros((dataset_size, max_seq_len), dtype=np.int32) output = np.zeros((dataset_size, max_seq_len), dtype=np.int32)
mask_3d = np.empty((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8) mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8)
for j, sentence in enumerate(dataset['sentence']): for j, sentence in enumerate(dataset['sentence']):
tokens = nltk.word_tokenize(sentence.decode("utf-8")) tokens = nltk.word_tokenize(sentence)
tensor = output[j] tensor = output[j]
mask_1d = np.zeros((1, max_seq_len), dtype=np.uint8)
for i, word in enumerate(tokens): for i, word in enumerate(tokens):
if i >= max_seq_len: if i >= max_seq_len:
...@@ -195,26 +174,31 @@ def data_preprocess(dataset, vocab, word_id, max_seq_len): ...@@ -195,26 +174,31 @@ def data_preprocess(dataset, vocab, word_id, max_seq_len):
else: else:
tensor[i] = vocab[word] tensor[i] = vocab[word]
mask_1d[0, i] = 1 seq_len = len(tokens)
if seq_len > max_seq_len:
seq_len = max_seq_len
mask_2d = mask_3d[j] mask_2d = mask_3d[j]
np.dot(mask_1d.T, mask_1d, out=mask_2d) mask_2d[:seq_len, :seq_len] = 0
np.subtract(1, mask_2d, out=mask_2d)
dataset['sentence'] = output new_dataset = {
dataset['label'] = dataset['label'].astype(np.float32) 'sentence': output,
dataset['mask'] = mask_3d.reshape((dataset_size, 1, max_seq_len, max_seq_len)) 'label': dataset['label'].astype(np.float32),
return dataset, vocab, word_id 'mask': mask_3d.reshape((dataset_size, 1, max_seq_len, max_seq_len))
}
return new_dataset, vocab, word_id
def get_datasets(max_seq_len): def get_datasets(max_seq_len):
"""Load GLUE train and test datasets into memory.""" """Load GLUE train and test datasets into memory."""
vocab = {} vocab = {}
word_id = 0 word_id = 0
dataset = 'glue/cola'
train_ds = tfds.as_numpy(tfds.load(dataset, split='train', batch_size=-1)) train_ds = load_dataset('glue', 'cola', split='train')
train_ds.set_format(type='np')
train_ds, vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len) train_ds, vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len)
test_ds = tfds.as_numpy(tfds.load(dataset, split='validation', batch_size=-1))
test_ds = load_dataset('glue', 'cola', split='validation')
test_ds.set_format(type='np')
test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len) test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len)
return train_ds, test_ds, word_id return train_ds, test_ds, word_id
...@@ -238,7 +222,7 @@ def get_params_pspec(sharding_rules, abs_var_collect): ...@@ -238,7 +222,7 @@ def get_params_pspec(sharding_rules, abs_var_collect):
return jax.sharding.PartitionSpec(*partitions) return jax.sharding.PartitionSpec(*partitions)
params_axes = abs_var_collect.get(PARAMS_AXES_KEY, {}) params_axes = abs_var_collect.get(PARAMS_AXES_KEY, {})
params_axes_pspec = jax.tree_map(to_device_axis, nn.partitioning.get_axis_names(params_axes)) params_axes_pspec = jax.tree_map(to_device_axis, nn_partitioning.get_axis_names(params_axes))
params_pspec = jax.tree_map(lambda x: jax.sharding.PartitionSpec(), abs_var_collect[PARAMS_KEY]) params_pspec = jax.tree_map(lambda x: jax.sharding.PartitionSpec(), abs_var_collect[PARAMS_KEY])
params_pspec = FrozenDict({**params_pspec, **params_axes_pspec}) params_pspec = FrozenDict({**params_pspec, **params_axes_pspec})
return params_pspec return params_pspec
...@@ -257,14 +241,12 @@ def get_state_pspec(state, params_pspec): ...@@ -257,14 +241,12 @@ def get_state_pspec(state, params_pspec):
def train_and_evaluate(args): def train_and_evaluate(args):
"""Execute model training and evaluation loop.""" """Execute model training and evaluation loop."""
print(args) print(args)
check_num_gpu(args.num_gpu) train_ds, test_ds, num_embed = get_datasets(args.max_seq_len)
if args.use_fp8:
assert gpu_has_fp8(), "GPU needs to support FP8."
num_gpu = jax.local_device_count()
num_gpu_tp = 2 num_gpu_tp = 2
if args.num_gpu % num_gpu_tp == 0: if num_gpu % num_gpu_tp == 0:
num_gpu_dp = args.num_gpu // num_gpu_tp num_gpu_dp = num_gpu // num_gpu_tp
else: else:
num_gpu_dp = 1 num_gpu_dp = 1
num_gpu_tp = 1 num_gpu_tp = 1
...@@ -287,14 +269,13 @@ def train_and_evaluate(args): ...@@ -287,14 +269,13 @@ def train_and_evaluate(args):
with te.fp8_autocast(args.use_fp8, with te.fp8_autocast(args.use_fp8,
sharding_resource=te.ShardingResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS)): sharding_resource=te.ShardingResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS)):
train_ds, test_ds, num_embed = get_datasets(args.max_seq_len)
encoder = Net(num_embed) encoder = Net(num_embed)
inputs = jnp.zeros(input_shape, dtype=jnp.int32) inputs = jnp.zeros(input_shape, dtype=jnp.int32)
masks = jnp.zeros(mask_shape, dtype=jnp.uint8) masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks) abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks)
customized_rules = ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS)) customized_rules = ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS))
sharding_rules = te.extend_logical_axis_rules(tuple()) + customized_rules sharding_rules = te.flax.extend_logical_axis_rules(tuple()) + customized_rules
params_pspec = get_params_pspec(sharding_rules, abs_var_collect) params_pspec = get_params_pspec(sharding_rules, abs_var_collect)
inputs_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None) inputs_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None)
masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None) masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None)
...@@ -356,13 +337,6 @@ def train_and_evaluate(args): ...@@ -356,13 +337,6 @@ def train_and_evaluate(args):
def encoder_parser(args): def encoder_parser(args):
"""Training settings.""" """Training settings."""
parser = argparse.ArgumentParser(description="JAX Encoder Example") parser = argparse.ArgumentParser(description="JAX Encoder Example")
parser.add_argument(
"--num-gpu",
type=int,
default=8,
metavar="N",
help="number of GPUs (default: 8)",
)
parser.add_argument( parser.add_argument(
"--batch-size", "--batch-size",
type=int, type=int,
...@@ -416,20 +390,19 @@ def encoder_parser(args): ...@@ -416,20 +390,19 @@ def encoder_parser(args):
class TestEncoder(unittest.TestCase): class TestEncoder(unittest.TestCase):
"""Encoder unittests""" """Encoder unittests"""
gpu_has_fp8, reason = te.fp8.is_fp8_available()
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
"""Run 3 epochs for testing""" """Run 3 epochs for testing"""
num_gpu = len(jax.local_devices()) cls.args = encoder_parser(["--epochs", "3"])
if num_gpu % 2 != 0:
num_gpu = 1
cls.args = encoder_parser(["--epochs", "3", "--num-gpu", str(num_gpu)])
def test_te_bf16(self): def test_te_bf16(self):
"""Test Transformer Engine with BF16""" """Test Transformer Engine with BF16"""
actual = train_and_evaluate(self.args) actual = train_and_evaluate(self.args)
assert actual[0] < 0.45 and actual[1] > 0.79 assert actual[0] < 0.45 and actual[1] > 0.79
@unittest.skipIf(not gpu_has_fp8(), reason='GPU capability is not enough to run FP8') @unittest.skipIf(not gpu_has_fp8, reason)
def test_te_fp8(self): def test_te_fp8(self):
"""Test Transformer Engine with FP8""" """Test Transformer Engine with FP8"""
self.args.use_fp8 = True self.args.use_fp8 = True
......
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# #
# See LICENSE for license information. # See LICENSE for license information.
""" Encoder training on multi-GPU with data parallelism""" """Encoder training on multi-GPU with data parallelism"""
import argparse import argparse
import unittest import unittest
from functools import partial from functools import partial
...@@ -11,10 +11,10 @@ import jax.numpy as jnp ...@@ -11,10 +11,10 @@ import jax.numpy as jnp
import nltk import nltk
import numpy as np import numpy as np
import optax import optax
import tensorflow_datasets as tfds from datasets import load_dataset
from cuda import cudart
from flax import linen as nn from flax import linen as nn
from flax.core.frozen_dict import FrozenDict from flax.core.frozen_dict import FrozenDict
from flax.linen import partitioning as nn_partitioning
from flax.training import train_state from flax.training import train_state
from jax.experimental import mesh_utils from jax.experimental import mesh_utils
from jax.experimental.pjit import pjit from jax.experimental.pjit import pjit
...@@ -28,26 +28,6 @@ DROPOUT_KEY = 'dropout' ...@@ -28,26 +28,6 @@ DROPOUT_KEY = 'dropout'
INPUT_KEY = 'input_rng' INPUT_KEY = 'input_rng'
def check_num_gpu(desired_num_gpu):
"""Check if the number of GPUs are correct."""
actual_num_gpu = len(jax.local_devices())
assert actual_num_gpu == desired_num_gpu, f"Number of GPUs is mismatch. " \
f"{desired_num_gpu} GPUs are assigned, but the actual number of GPUs is {actual_num_gpu}"
def gpu_has_fp8():
"""Check if the GPU has FP8."""
cudaSuccess = cudart.cudaError_t.cudaSuccess
ret, gpu_id = cudart.cudaGetDevice()
assert ret == cudaSuccess
flag = cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMajor
_, major = cudart.cudaDeviceGetAttribute(flag, gpu_id)
flag = cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMinor
_, minor = cudart.cudaDeviceGetAttribute(flag, gpu_id)
sm_arch = major * 10 + minor
return sm_arch >= 89
class Net(nn.Module): class Net(nn.Module):
"""NLP Encoder""" """NLP Encoder"""
num_embed: int num_embed: int
...@@ -63,7 +43,7 @@ class Net(nn.Module): ...@@ -63,7 +43,7 @@ class Net(nn.Module):
hidden_dropout=0.1, hidden_dropout=0.1,
attention_dropout=0.1, attention_dropout=0.1,
dropout_rng_name=DROPOUT_KEY, dropout_rng_name=DROPOUT_KEY,
layer_type=te.TransformerLayerType.ENCODER, layer_type=te.flax.TransformerLayerType.ENCODER,
enable_relative_embedding=False, enable_relative_embedding=False,
dtype=jnp.bfloat16) dtype=jnp.bfloat16)
x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout) x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
...@@ -168,12 +148,11 @@ def data_preprocess(dataset, vocab, word_id, max_seq_len): ...@@ -168,12 +148,11 @@ def data_preprocess(dataset, vocab, word_id, max_seq_len):
nltk.download('punkt') nltk.download('punkt')
dataset_size = len(dataset['sentence']) dataset_size = len(dataset['sentence'])
output = np.zeros((dataset_size, max_seq_len), dtype=np.int32) output = np.zeros((dataset_size, max_seq_len), dtype=np.int32)
mask_3d = np.empty((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8) mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8)
for j, sentence in enumerate(dataset['sentence']): for j, sentence in enumerate(dataset['sentence']):
tokens = nltk.word_tokenize(sentence.decode("utf-8")) tokens = nltk.word_tokenize(sentence)
tensor = output[j] tensor = output[j]
mask_1d = np.zeros((1, max_seq_len), dtype=np.uint8)
for i, word in enumerate(tokens): for i, word in enumerate(tokens):
if i >= max_seq_len: if i >= max_seq_len:
...@@ -186,26 +165,31 @@ def data_preprocess(dataset, vocab, word_id, max_seq_len): ...@@ -186,26 +165,31 @@ def data_preprocess(dataset, vocab, word_id, max_seq_len):
else: else:
tensor[i] = vocab[word] tensor[i] = vocab[word]
mask_1d[0, i] = 1 seq_len = len(tokens)
if seq_len > max_seq_len:
seq_len = max_seq_len
mask_2d = mask_3d[j] mask_2d = mask_3d[j]
np.dot(mask_1d.T, mask_1d, out=mask_2d) mask_2d[:seq_len, :seq_len] = 0
np.subtract(1, mask_2d, out=mask_2d)
dataset['sentence'] = output new_dataset = {
dataset['label'] = dataset['label'].astype(np.float32) 'sentence': output,
dataset['mask'] = mask_3d.reshape((dataset_size, 1, max_seq_len, max_seq_len)) 'label': dataset['label'].astype(np.float32),
return dataset, vocab, word_id 'mask': mask_3d.reshape((dataset_size, 1, max_seq_len, max_seq_len))
}
return new_dataset, vocab, word_id
def get_datasets(max_seq_len): def get_datasets(max_seq_len):
"""Load GLUE train and test datasets into memory.""" """Load GLUE train and test datasets into memory."""
vocab = {} vocab = {}
word_id = 0 word_id = 0
dataset = 'glue/cola'
train_ds = tfds.as_numpy(tfds.load(dataset, split='train', batch_size=-1)) train_ds = load_dataset('glue', 'cola', split='train')
train_ds.set_format(type='np')
train_ds, vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len) train_ds, vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len)
test_ds = tfds.as_numpy(tfds.load(dataset, split='validation', batch_size=-1))
test_ds = load_dataset('glue', 'cola', split='validation')
test_ds.set_format(type='np')
test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len) test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len)
return train_ds, test_ds, word_id return train_ds, test_ds, word_id
...@@ -229,7 +213,7 @@ def get_params_pspec(sharding_rules, abs_var_collect): ...@@ -229,7 +213,7 @@ def get_params_pspec(sharding_rules, abs_var_collect):
return jax.sharding.PartitionSpec(*partitions) return jax.sharding.PartitionSpec(*partitions)
params_axes = abs_var_collect.get(PARAMS_AXES_KEY, {}) params_axes = abs_var_collect.get(PARAMS_AXES_KEY, {})
params_axes_pspec = jax.tree_map(to_device_axis, nn.partitioning.get_axis_names(params_axes)) params_axes_pspec = jax.tree_map(to_device_axis, nn_partitioning.get_axis_names(params_axes))
params_pspec = jax.tree_map(lambda x: jax.sharding.PartitionSpec(), abs_var_collect[PARAMS_KEY]) params_pspec = jax.tree_map(lambda x: jax.sharding.PartitionSpec(), abs_var_collect[PARAMS_KEY])
params_pspec = FrozenDict({**params_pspec, **params_axes_pspec}) params_pspec = FrozenDict({**params_pspec, **params_axes_pspec})
return params_pspec return params_pspec
...@@ -248,15 +232,14 @@ def get_state_pspec(state, params_pspec): ...@@ -248,15 +232,14 @@ def get_state_pspec(state, params_pspec):
def train_and_evaluate(args): def train_and_evaluate(args):
"""Execute model training and evaluation loop.""" """Execute model training and evaluation loop."""
print(args) print(args)
check_num_gpu(args.num_gpu) train_ds, test_ds, num_embed = get_datasets(args.max_seq_len)
assert args.batch_size % args.num_gpu == 0, f"Batch size needs to be multiple of {args.num_gpu}"
assert args.test_batch_size % args.num_gpu == 0, \
f"Test batch size needs to be multiple of {args.num_gpu}"
if args.use_fp8: num_gpu = jax.local_device_count()
assert gpu_has_fp8(), "GPU needs to support FP8." assert args.batch_size % num_gpu == 0, f"Batch size needs to be multiple of {num_gpu}"
assert args.test_batch_size % num_gpu == 0, \
f"Test batch size needs to be multiple of {num_gpu}"
device_mesh = mesh_utils.create_device_mesh((args.num_gpu,)) device_mesh = mesh_utils.create_device_mesh((num_gpu,))
with jax.sharding.Mesh(devices=device_mesh, axis_names=(DEVICE_DP_AXIS,)): with jax.sharding.Mesh(devices=device_mesh, axis_names=(DEVICE_DP_AXIS,)):
rng = jax.random.PRNGKey(args.seed) rng = jax.random.PRNGKey(args.seed)
...@@ -269,13 +252,12 @@ def train_and_evaluate(args): ...@@ -269,13 +252,12 @@ def train_and_evaluate(args):
label_shape = [args.batch_size] label_shape = [args.batch_size]
with te.fp8_autocast(args.use_fp8, sharding_resource=te.ShardingResource(DEVICE_DP_AXIS)): with te.fp8_autocast(args.use_fp8, sharding_resource=te.ShardingResource(DEVICE_DP_AXIS)):
train_ds, test_ds, num_embed = get_datasets(args.max_seq_len)
encoder = Net(num_embed) encoder = Net(num_embed)
inputs = jnp.zeros(input_shape, dtype=jnp.int32) inputs = jnp.zeros(input_shape, dtype=jnp.int32)
masks = jnp.zeros(mask_shape, dtype=jnp.uint8) masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks) abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks)
sharding_rules = te.extend_logical_axis_rules(tuple()) sharding_rules = te.flax.extend_logical_axis_rules(tuple())
params_pspec = get_params_pspec(sharding_rules, abs_var_collect) params_pspec = get_params_pspec(sharding_rules, abs_var_collect)
inputs_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None) inputs_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None)
masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None) masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None)
...@@ -337,26 +319,19 @@ def train_and_evaluate(args): ...@@ -337,26 +319,19 @@ def train_and_evaluate(args):
def encoder_parser(args): def encoder_parser(args):
"""Training settings.""" """Training settings."""
parser = argparse.ArgumentParser(description="JAX Encoder Example") parser = argparse.ArgumentParser(description="JAX Encoder Example")
parser.add_argument(
"--num-gpu",
type=int,
default=8,
metavar="N",
help="number of GPUs (default: 8)",
)
parser.add_argument( parser.add_argument(
"--batch-size", "--batch-size",
type=int, type=int,
default=64, default=128,
metavar="N", metavar="N",
help="input batch size for training (default: 64)", help="input batch size for training (default: 128)",
) )
parser.add_argument( parser.add_argument(
"--test-batch-size", "--test-batch-size",
type=int, type=int,
default=64, default=128,
metavar="N", metavar="N",
help="input batch size for testing (default: 64)", help="input batch size for testing (default: 128)",
) )
parser.add_argument( parser.add_argument(
"--max-seq-len", "--max-seq-len",
...@@ -397,25 +372,24 @@ def encoder_parser(args): ...@@ -397,25 +372,24 @@ def encoder_parser(args):
class TestEncoder(unittest.TestCase): class TestEncoder(unittest.TestCase):
"""Encoder unittests""" """Encoder unittests"""
gpu_has_fp8, reason = te.fp8.is_fp8_available()
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
"""Run 3 epochs for testing""" """Run 3 epochs for testing"""
num_gpu = len(jax.local_devices()) cls.args = encoder_parser(["--epochs", "3"])
if num_gpu % 2 != 0:
num_gpu = 1
cls.args = encoder_parser(["--epochs", "3", "--num-gpu", str(num_gpu)])
def test_te_bf16(self): def test_te_bf16(self):
"""Test Transformer Engine with BF16""" """Test Transformer Engine with BF16"""
actual = train_and_evaluate(self.args) actual = train_and_evaluate(self.args)
assert actual[0] < 0.45 and actual[1] > 0.79 assert actual[0] < 0.49 and actual[1] > 0.76
@unittest.skipIf(not gpu_has_fp8(), reason='GPU capability is not enough to run FP8') @unittest.skipIf(not gpu_has_fp8, reason)
def test_te_fp8(self): def test_te_fp8(self):
"""Test Transformer Engine with FP8""" """Test Transformer Engine with FP8"""
self.args.use_fp8 = True self.args.use_fp8 = True
actual = train_and_evaluate(self.args) actual = train_and_evaluate(self.args)
assert actual[0] < 0.45 and actual[1] > 0.79 assert actual[0] < 0.49 and actual[1] > 0.76
if __name__ == "__main__": if __name__ == "__main__":
......
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Encoder training with multi-GPU, multiprocessing, and tensor parallelism"""
import argparse
import multiprocessing as mp
import os
import unittest
from functools import partial
import jax
import jax.numpy as jnp
import nltk
import numpy as np
import optax
from datasets import load_dataset
from flax import linen as nn
from flax.core.frozen_dict import FrozenDict
from flax.linen import partitioning as nn_partitioning
from flax.training import train_state
from jax.experimental import mesh_utils
from jax.experimental.pjit import pjit
import transformer_engine.jax as te
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
DEVICE_DP_AXIS = 'data'
DEVICE_TP_AXIS = 'model'
NAMED_BROADCAST_AXIS = 'my_broadcast_axis'
NAMED_TP_AXIS = 'my_tp_axis'
PARAMS_KEY = 'params'
PARAMS_AXES_KEY = PARAMS_KEY + '_axes'
DROPOUT_KEY = 'dropout'
INPUT_KEY = 'input_rng'
class Net(nn.Module):
"""NLP Encoder"""
num_embed: int
@nn.compact
def __call__(self, x, mask, disable_dropout=False):
x = nn.Embed(num_embeddings=self.num_embed, features=256, dtype=jnp.bfloat16)(x)
te_Encoder = partial(te.flax.TransformerLayer,
hidden_size=256,
mlp_hidden_size=1024,
num_attention_heads=8,
hidden_dropout=0.1,
attention_dropout=0.1,
dropout_rng_name=DROPOUT_KEY,
layer_type=te.flax.TransformerLayerType.ENCODER,
enable_relative_embedding=False,
dtype=jnp.bfloat16)
x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
x = x.reshape(x.shape[0], -1)
x = te.flax.DenseGeneral(features=256,
kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS),
bias_axes=(NAMED_TP_AXIS,),
sharding_type=te.ShardingType.DP_TP_COL,
dtype=jnp.bfloat16)(x)
x = te.flax.DenseGeneral(features=256,
kernel_axes=(NAMED_TP_AXIS, NAMED_BROADCAST_AXIS),
bias_axes=(NAMED_BROADCAST_AXIS,),
sharding_type=te.ShardingType.DP_TP_ROW,
dtype=jnp.bfloat16)(x)
x = nn.Dense(features=2, dtype=jnp.bfloat16)(x)
return x
def valid_shard_size(total_size, batch_size, dp_size, tp_size):
"""Get sharded input shape"""
global_batch_size = dp_size * batch_size
num_steps = total_size // global_batch_size
valid_size = num_steps * global_batch_size
gpu_id = jax.local_devices()[0].id
tp_group_id = gpu_id // tp_size
return valid_size, global_batch_size, num_steps, tp_group_id
def shard_array_wrapper(dataset, batch_size, mesh, pspec, enable_partition=False):
"""Generate needed args for jax.make_array_from_single_device_arrays"""
inputs = jnp.asarray(dataset)
total_input_size = len(inputs)
(dp_size, tp_size) = mesh.device_ids.shape
valid_input_size, global_batch_size, num_steps, tp_group_id = valid_shard_size(
total_input_size, batch_size, dp_size, tp_size)
inputs = inputs[:valid_input_size] # skip incomplete batch
single_input_shape = inputs.shape[1:]
global_input_shape = (global_batch_size, *single_input_shape)
named_sharding = jax.sharding.NamedSharding(mesh, pspec)
if enable_partition:
inputs = inputs.reshape(dp_size, num_steps, batch_size, *single_input_shape)
inputs = inputs[tp_group_id]
return global_input_shape, named_sharding, inputs
def train_step(state, inputs, masks, labels, var_collect, rngs, use_fp8):
"""Computes gradients, loss and accuracy for a single batch."""
def loss_fn(var_collect, disable_dropout=False):
logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs)
one_hot = jax.nn.one_hot(labels, 2)
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits
var_collect = FrozenDict({**var_collect, PARAMS_KEY: state.params})
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, logits), grads = grad_fn(var_collect)
accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
var_collect, grads = grads.pop(PARAMS_KEY)
state = state.apply_gradients(grads=grads)
if use_fp8:
var_collect = te.update_fp8_metas(var_collect)
return state, loss, accuracy, var_collect
def train_epoch(state, train_ds, batch_size, rngs, var_collect, use_fp8, train_fn, mesh,
inputs_pspec, masks_pspec, labels_pspec):
"""Train for a single epoch."""
total_batch_size = len(train_ds['sentence'])
(dp_size, tp_size) = mesh.device_ids.shape
valid_size, _, num_steps, tp_group_id = valid_shard_size(total_batch_size, batch_size, dp_size,
tp_size)
perms = jax.random.permutation(rngs[INPUT_KEY], valid_size)
perms = perms.reshape(dp_size, num_steps, batch_size)
perms = perms[tp_group_id]
global_input_shape, input_named_sharding, sentence = shard_array_wrapper(
train_ds['sentence'], batch_size, mesh, inputs_pspec)
global_mask_shape, mask_named_sharding, mask = shard_array_wrapper(
train_ds['mask'], batch_size, mesh, masks_pspec)
global_label_shape, label_named_sharding, label = shard_array_wrapper(
train_ds['label'], batch_size, mesh, labels_pspec)
epoch_loss = []
epoch_accuracy = []
for perm in perms:
batch_input = sentence[perm, ...]
batch_mask = mask[perm, ...]
batch_label = label[perm, ...]
shard_input = jax.make_array_from_single_device_arrays(global_input_shape,
input_named_sharding, [batch_input])
shard_mask = jax.make_array_from_single_device_arrays(global_mask_shape,
mask_named_sharding, [batch_mask])
shard_label = jax.make_array_from_single_device_arrays(global_label_shape,
label_named_sharding, [batch_label])
state, loss, accuracy, var_collect = train_fn(state, shard_input, shard_mask, shard_label,
var_collect, rngs, use_fp8)
epoch_loss.append(loss)
epoch_accuracy.append(accuracy)
avg_loss = np.mean(epoch_loss)
avg_accuracy = np.mean(epoch_accuracy)
return state, avg_loss, avg_accuracy, var_collect
def eval_step(state, inputs, masks, labels, var_collect):
"""Computes loss and accuracy for a single batch."""
def loss_fn(var_collect, disable_dropout=False):
logits = state.apply_fn(var_collect, inputs, masks, disable_dropout)
one_hot = jax.nn.one_hot(labels, 2)
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits
var_collect = FrozenDict({**var_collect, PARAMS_KEY: state.params})
loss, logits = loss_fn(var_collect, disable_dropout=True)
accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
return loss, accuracy
def eval_model(state, test_ds, batch_size, var_collect, eval_fn, mesh, inputs_pspec, masks_pspec,
labels_pspec):
"""Evaluation loop."""
global_input_shape, input_named_sharding, sentence = shard_array_wrapper(test_ds['sentence'],
batch_size,
mesh,
inputs_pspec,
enable_partition=True)
global_mask_shape, mask_named_sharding, mask = shard_array_wrapper(test_ds['mask'],
batch_size,
mesh,
masks_pspec,
enable_partition=True)
global_label_shape, label_named_sharding, label = shard_array_wrapper(test_ds['label'],
batch_size,
mesh,
labels_pspec,
enable_partition=True)
all_loss = []
all_accuracy = []
for batch_input, batch_mask, batch_label in zip(sentence, mask, label):
shard_input = jax.make_array_from_single_device_arrays(global_input_shape,
input_named_sharding, [batch_input])
shard_mask = jax.make_array_from_single_device_arrays(global_mask_shape,
mask_named_sharding, [batch_mask])
shard_label = jax.make_array_from_single_device_arrays(global_label_shape,
label_named_sharding, [batch_label])
loss, accuracy = eval_fn(state, shard_input, shard_mask, shard_label, var_collect)
all_loss.append(loss)
all_accuracy.append(accuracy)
avg_loss = np.mean(all_loss)
avg_accuracy = np.mean(all_accuracy)
return avg_loss, avg_accuracy
def data_preprocess(dataset, vocab, word_id, max_seq_len):
"""Convert tokens to numbers."""
nltk.download('punkt')
dataset_size = len(dataset['sentence'])
output = np.zeros((dataset_size, max_seq_len), dtype=np.int32)
mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8)
for j, sentence in enumerate(dataset['sentence']):
tokens = nltk.word_tokenize(sentence)
tensor = output[j]
for i, word in enumerate(tokens):
if i >= max_seq_len:
break
if word not in vocab:
vocab[word] = word_id
tensor[i] = word_id
word_id = word_id + 1
else:
tensor[i] = vocab[word]
seq_len = len(tokens)
if seq_len > max_seq_len:
seq_len = max_seq_len
mask_2d = mask_3d[j]
mask_2d[:seq_len, :seq_len] = 0
new_dataset = {
'sentence': output,
'label': dataset['label'].astype(np.float32),
'mask': mask_3d.reshape((dataset_size, 1, max_seq_len, max_seq_len))
}
return new_dataset, vocab, word_id
def get_datasets(max_seq_len):
"""Load GLUE train and test datasets into memory."""
vocab = {}
word_id = 0
train_ds = load_dataset('glue', 'cola', split='train')
train_ds.set_format(type='np')
train_ds, vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len)
test_ds = load_dataset('glue', 'cola', split='validation')
test_ds.set_format(type='np')
test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len)
return train_ds, test_ds, word_id
def check_fp8(state, var_collect, inputs, masks, labels):
"Check if model includes FP8."
rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)}
assert "Float8" in str(
jax.make_jaxpr(train_step, static_argnums=6)(state, inputs, masks, labels, var_collect,
rngs, True))
def get_params_pspec(sharding_rules, abs_var_collect):
"""Refer params to create params partition spec"""
rules_dict = {}
for key, value in sharding_rules:
rules_dict[key] = value
def to_device_axis(logical_axis):
partitions = [rules_dict[key] for key in logical_axis]
return jax.sharding.PartitionSpec(*partitions)
params_axes = abs_var_collect.get(PARAMS_AXES_KEY, {})
params_axes_pspec = jax.tree_map(to_device_axis, nn_partitioning.get_axis_names(params_axes))
params_pspec = jax.tree_map(lambda x: jax.sharding.PartitionSpec(), abs_var_collect[PARAMS_KEY])
params_pspec = FrozenDict({**params_pspec, **params_axes_pspec})
return params_pspec
def get_state_pspec(state, params_pspec):
"""Refer params_pspec to create state partition spec"""
def replace_params(x):
return params_pspec if isinstance(x, FrozenDict) else None
state_pspec = jax.tree_map(replace_params, state, is_leaf=lambda x: isinstance(x, FrozenDict))
return state_pspec
def train_and_evaluate(args):
"""Execute model training and evaluation loop."""
print(args)
train_ds, test_ds, num_embed = get_datasets(args.max_seq_len)
jax.distributed.initialize(coordinator_address=args.coordinator_address,
num_processes=args.num_process,
process_id=args.process_id,
local_device_ids=args.process_id)
assert jax.local_device_count() == 1, "1 GPU per process"
num_gpu_tp = 2
if args.num_process % num_gpu_tp == 0:
num_gpu_dp = args.num_process // num_gpu_tp
else:
assert args.num_process == 1, "number of processes should be multiple of 2, or 1"
num_gpu_dp = 1
num_gpu_tp = 1
assert args.batch_size % num_gpu_dp == 0, f"Batch size needs to be multiple of {num_gpu_dp}"
assert args.test_batch_size % num_gpu_dp == 0, \
f"Test batch size needs to be multiple of {num_gpu_dp}"
device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp))
with jax.sharding.Mesh(devices=device_mesh,
axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS)) as shard_mesh:
rng = jax.random.PRNGKey(args.seed)
rng, params_rng = jax.random.split(rng)
rng, dropout_rng = jax.random.split(rng)
init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng}
input_shape = [args.batch_size, args.max_seq_len]
mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
label_shape = [args.batch_size]
with te.fp8_autocast(args.use_fp8,
sharding_resource=te.ShardingResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS)):
encoder = Net(num_embed)
inputs = jnp.zeros(input_shape, dtype=jnp.int32)
masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks)
customized_rules = ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS))
sharding_rules = te.flax.extend_logical_axis_rules(tuple()) + customized_rules
params_pspec = get_params_pspec(sharding_rules, abs_var_collect)
inputs_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None)
masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None)
in_shardings = (None, inputs_pspec, masks_pspec)
out_shardings = FrozenDict({key: params_pspec if key is PARAMS_KEY else None \
for key in abs_var_collect})
pjit_encoder_init = pjit(encoder.init, in_shardings, out_shardings)
var_collect = pjit_encoder_init(init_rngs, inputs, masks)
optimizer = optax.adamw(args.lr)
var_collect, params = var_collect.pop(PARAMS_KEY)
state = train_state.TrainState.create(apply_fn=encoder.apply,
params=params,
tx=optimizer)
state_pspec = get_state_pspec(state, params_pspec)
labels_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS,)
in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None, None)
out_shardings = (state_pspec, None, None, None)
pjit_train_step = pjit(train_step, in_shardings, out_shardings, static_argnums=(6,))
in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None)
out_shardings = (None, None)
pjit_eval_step = pjit(eval_step, in_shardings, out_shardings)
if args.use_fp8:
labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
check_fp8(state, var_collect, inputs, masks, labels)
if args.dry_run:
labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
rngs = {DROPOUT_KEY: dropout_rng}
pjit_train_step(state, inputs, masks, labels, var_collect, rngs, args.use_fp8)
print("PASSED")
else:
for epoch in range(1, args.epochs + 1):
rng, input_rng = jax.random.split(rng)
rng, dropout_rng = jax.random.split(rng)
rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}
state, train_loss, train_accuracy, var_collect = train_epoch(
state, train_ds, args.batch_size, rngs, var_collect, args.use_fp8,
pjit_train_step, shard_mesh, inputs_pspec, masks_pspec, labels_pspec)
test_loss, test_accuracy = eval_model(state, test_ds, args.test_batch_size,
var_collect, pjit_eval_step, shard_mesh,
inputs_pspec, masks_pspec, labels_pspec)
if args.process_id == 0:
print(f"Epoch: {epoch:>2} "
f"Train Loss: {train_loss:.6f} "
f"Train Accuracy: {train_accuracy:.6f} "
f"Test Loss: {test_loss:.6f} "
f"Test Accuracy: {test_accuracy:.6f} ")
jax.distributed.shutdown()
return [train_loss, train_accuracy, test_loss, test_accuracy]
def encoder_parser(args):
"""Training settings."""
parser = argparse.ArgumentParser(description="JAX Encoder Example")
parser.add_argument(
"--batch-size",
type=int,
default=64,
metavar="N",
help="input batch size for training (default: 64)",
)
parser.add_argument(
"--test-batch-size",
type=int,
default=64,
metavar="N",
help="input batch size for testing (default: 64)",
)
parser.add_argument(
"--max-seq-len",
type=int,
default=32,
metavar="N",
help="maximum sequence length (default: 32)",
)
parser.add_argument(
"--epochs",
type=int,
default=3,
metavar="N",
help="number of epochs to train (default: 3)",
)
parser.add_argument(
"--lr",
type=float,
default=0.0001,
metavar="LR",
help="learning rate (default: 0.0001)",
)
parser.add_argument(
"--dry-run",
action="store_true",
default=False,
help="quickly check a single pass",
)
parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
parser.add_argument("--use-fp8",
action="store_true",
default=False,
help="Use FP8 for inference and training without recalibration")
parser.add_argument("--coordinator-address",
type=str,
default="127.0.0.1:1234",
help="the IP address of process 0 and a port on \
which that process should launch a coordinator service \
(default: 127.0.0.1:1234)")
parser.add_argument("--num-process",
type=int,
default=1,
help="number of processes (default: 1)")
parser.add_argument("--process-id",
type=int,
default=0,
help="the ID number of the current process (default: 0)")
return parser.parse_args(args)
def query_gpu(q):
"""Query GPU info on the system"""
gpu_has_fp8, reason = te.fp8.is_fp8_available()
num_gpu = len(jax.devices())
q.put([num_gpu, gpu_has_fp8, reason])
def unittest_query_gpu():
r"""
It is only used by TestEncoder.
The `jax.distributed.initialize` must be called before any other JAX or Flax API,
otherwise `jax.local_devices` will be incorrect.
Thus, fork another process to query number of GPUs and FP8 capability.
"""
q = mp.Queue()
p = mp.Process(target=query_gpu, args=(q,))
p.start()
num_gpu, gpu_has_fp8, reason = q.get()
p.join()
return num_gpu, gpu_has_fp8, reason
class TestEncoder(unittest.TestCase):
"""Encoder unittests"""
num_gpu, gpu_has_fp8, reason = unittest_query_gpu()
def exec(self, use_fp8):
"""Run 3 epochs for testing"""
num_gpu = self.num_gpu
tp_size = 2 if num_gpu > 1 and num_gpu % 2 == 0 else 1
dp_size = num_gpu // tp_size
batch_size = 64 // dp_size
arg_list = []
for i in range(num_gpu):
args = encoder_parser([])
args.num_process = num_gpu
args.use_fp8 = use_fp8
args.batch_size = batch_size
args.test_batch_size = batch_size
args.process_id = i
arg_list.append(args)
with mp.Pool(self.num_gpu) as p:
results = p.map(train_and_evaluate, arg_list)
return results
def test_te_bf16(self):
"""Test Transformer Engine with BF16"""
results = self.exec(False)
actual = results[0]
assert actual[0] < 0.45 and actual[1] > 0.79
@unittest.skipIf(not gpu_has_fp8, reason)
def test_te_fp8(self):
"""Test Transformer Engine with FP8"""
results = self.exec(True)
actual = results[0]
assert actual[0] < 0.45 and actual[1] > 0.79
if __name__ == "__main__":
train_and_evaluate(encoder_parser(None))
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# #
# See LICENSE for license information. # See LICENSE for license information.
""" Encoder training on single GPU""" """Encoder training on single GPU"""
import argparse import argparse
import os
import unittest import unittest
from functools import partial from functools import partial
...@@ -12,8 +11,7 @@ import jax.numpy as jnp ...@@ -12,8 +11,7 @@ import jax.numpy as jnp
import nltk import nltk
import numpy as np import numpy as np
import optax import optax
import tensorflow_datasets as tfds from datasets import load_dataset
from cuda import cudart
from flax import linen as nn from flax import linen as nn
from flax.core.frozen_dict import FrozenDict from flax.core.frozen_dict import FrozenDict
from flax.training import train_state from flax.training import train_state
...@@ -25,19 +23,6 @@ DROPOUT_KEY = 'dropout' ...@@ -25,19 +23,6 @@ DROPOUT_KEY = 'dropout'
INPUT_KEY = 'input_rng' INPUT_KEY = 'input_rng'
def gpu_has_fp8():
"""Check if the GPU has FP8."""
cudaSuccess = cudart.cudaError_t.cudaSuccess
ret, gpu_id = cudart.cudaGetDevice()
assert ret == cudaSuccess
flag = cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMajor
_, major = cudart.cudaDeviceGetAttribute(flag, gpu_id)
flag = cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMinor
_, minor = cudart.cudaDeviceGetAttribute(flag, gpu_id)
sm_arch = major * 10 + minor
return sm_arch >= 89
class Net(nn.Module): class Net(nn.Module):
"""NLP Encoder""" """NLP Encoder"""
num_embed: int num_embed: int
...@@ -53,7 +38,7 @@ class Net(nn.Module): ...@@ -53,7 +38,7 @@ class Net(nn.Module):
hidden_dropout=0.1, hidden_dropout=0.1,
attention_dropout=0.1, attention_dropout=0.1,
dropout_rng_name=DROPOUT_KEY, dropout_rng_name=DROPOUT_KEY,
layer_type=te.TransformerLayerType.ENCODER, layer_type=te.flax.TransformerLayerType.ENCODER,
enable_relative_embedding=False, enable_relative_embedding=False,
dtype=jnp.bfloat16) dtype=jnp.bfloat16)
x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout) x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
...@@ -158,12 +143,11 @@ def data_preprocess(dataset, vocab, word_id, max_seq_len): ...@@ -158,12 +143,11 @@ def data_preprocess(dataset, vocab, word_id, max_seq_len):
nltk.download('punkt') nltk.download('punkt')
dataset_size = len(dataset['sentence']) dataset_size = len(dataset['sentence'])
output = np.zeros((dataset_size, max_seq_len), dtype=np.int32) output = np.zeros((dataset_size, max_seq_len), dtype=np.int32)
mask_3d = np.empty((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8) mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8)
for j, sentence in enumerate(dataset['sentence']): for j, sentence in enumerate(dataset['sentence']):
tokens = nltk.word_tokenize(sentence.decode("utf-8")) tokens = nltk.word_tokenize(sentence)
tensor = output[j] tensor = output[j]
mask_1d = np.zeros((1, max_seq_len), dtype=np.uint8)
for i, word in enumerate(tokens): for i, word in enumerate(tokens):
if i >= max_seq_len: if i >= max_seq_len:
...@@ -176,26 +160,31 @@ def data_preprocess(dataset, vocab, word_id, max_seq_len): ...@@ -176,26 +160,31 @@ def data_preprocess(dataset, vocab, word_id, max_seq_len):
else: else:
tensor[i] = vocab[word] tensor[i] = vocab[word]
mask_1d[0, i] = 1 seq_len = len(tokens)
if seq_len > max_seq_len:
seq_len = max_seq_len
mask_2d = mask_3d[j] mask_2d = mask_3d[j]
np.dot(mask_1d.T, mask_1d, out=mask_2d) mask_2d[:seq_len, :seq_len] = 0
np.subtract(1, mask_2d, out=mask_2d)
dataset['sentence'] = output new_dataset = {
dataset['label'] = dataset['label'].astype(np.float32) 'sentence': output,
dataset['mask'] = mask_3d.reshape((dataset_size, 1, max_seq_len, max_seq_len)) 'label': dataset['label'].astype(np.float32),
return dataset, vocab, word_id 'mask': mask_3d.reshape((dataset_size, 1, max_seq_len, max_seq_len))
}
return new_dataset, vocab, word_id
def get_datasets(max_seq_len): def get_datasets(max_seq_len):
"""Load GLUE train and test datasets into memory.""" """Load GLUE train and test datasets into memory."""
vocab = {} vocab = {}
word_id = 0 word_id = 0
dataset = 'glue/cola'
train_ds = tfds.as_numpy(tfds.load(dataset, split='train', batch_size=-1)) train_ds = load_dataset('glue', 'cola', split='train')
train_ds.set_format(type='np')
train_ds, vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len) train_ds, vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len)
test_ds = tfds.as_numpy(tfds.load(dataset, split='validation', batch_size=-1))
test_ds = load_dataset('glue', 'cola', split='validation')
test_ds.set_format(type='np')
test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len) test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len)
return train_ds, test_ds, word_id return train_ds, test_ds, word_id
@@ -210,11 +199,8 @@ def check_fp8(state, var_collect, inputs, masks, labels):
def train_and_evaluate(args):
    """Execute model training and evaluation loop."""
-   os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
    print(args)
+   train_ds, test_ds, num_embed = get_datasets(args.max_seq_len)

-   if args.use_fp8:
-       assert gpu_has_fp8(), "GPU needs to support FP8."

    rng = jax.random.PRNGKey(args.seed)
    rng, params_rng = jax.random.split(rng)
@@ -226,7 +212,6 @@ def train_and_evaluate(args):
    label_shape = [args.batch_size]

    with te.fp8_autocast(enabled=args.use_fp8):
-       train_ds, test_ds, num_embed = get_datasets(args.max_seq_len)
        encoder = Net(num_embed)
        inputs = jnp.zeros(input_shape, dtype=jnp.int32)
        masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
@@ -322,6 +307,8 @@ def encoder_parser(args):
class TestEncoder(unittest.TestCase):
    """Encoder unittests"""

+   gpu_has_fp8, reason = te.fp8.is_fp8_available()

    @classmethod
    def setUpClass(cls):
        """Run 4 epochs for testing"""
@@ -332,7 +319,7 @@ class TestEncoder(unittest.TestCase):
        actual = train_and_evaluate(self.args)
        assert actual[0] < 0.45 and actual[1] > 0.79

-   @unittest.skipIf(not gpu_has_fp8(), reason='GPU capability is not enough to run FP8')
+   @unittest.skipIf(not gpu_has_fp8, reason)
    def test_te_fp8(self):
        """Test Transformer Engine with FP8"""
        self.args.use_fp8 = True
...
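
For reference, a hedged sketch of the skip idiom the tests switch to: is_fp8_available() runs once when the class body is evaluated and returns both a flag and a human-readable reason, so skipped tests now report why FP8 is unavailable (the class and method names below are illustrative only).

# Sketch of the new skip idiom, assuming transformer_engine.jax is importable as te.
import unittest
import transformer_engine.jax as te

class TestExample(unittest.TestCase):
    # Evaluated once when the class body runs; returns (bool, reason-string).
    gpu_has_fp8, reason = te.fp8.is_fp8_available()

    @unittest.skipIf(not gpu_has_fp8, reason)
    def test_fp8_path(self):
        ...  # FP8-specific assertions go here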
+datasets
flax
optax
-tensorflow-datasets
+Pillow
@@ -3,7 +3,6 @@
# See LICENSE for license information.
""" MNIST training on single GPU"""
import argparse
-import os
import unittest
from functools import partial
@@ -11,8 +10,7 @@ import jax
import jax.numpy as jnp
import numpy as np
import optax
-import tensorflow_datasets as tfds
+from datasets import load_dataset
-from cuda import cudart
from flax import linen as nn
from flax.core.frozen_dict import FrozenDict
from flax.training import train_state
@@ -27,19 +25,6 @@ DROPOUT_KEY = 'dropout'
INPUT_KEY = 'input_rng'

-def gpu_has_fp8():
-    """Check if the GPU has FP8."""
-    cudaSuccess = cudart.cudaError_t.cudaSuccess
-    ret, gpu_id = cudart.cudaGetDevice()
-    assert ret == cudaSuccess
-    flag = cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMajor
-    _, major = cudart.cudaDeviceGetAttribute(flag, gpu_id)
-    flag = cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMinor
-    _, minor = cudart.cudaDeviceGetAttribute(flag, gpu_id)
-    sm_arch = major * 10 + minor
-    return sm_arch >= 89

class Net(nn.Module):
    """CNN model for MNIST."""
    use_te: bool = False
@@ -144,13 +129,23 @@ def eval_model(state, test_ds, batch_size, var_collect):
def get_datasets():
    """Load MNIST train and test datasets into memory."""
-   ds_builder = tfds.builder('mnist')
-   ds_builder.download_and_prepare()
-   train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
-   test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
-   train_ds['image'] = jnp.float32(train_ds['image']) / 255.
-   test_ds['image'] = jnp.float32(test_ds['image']) / 255.
-   return train_ds, test_ds
+   train_ds = load_dataset('mnist', split='train')
+   train_ds.set_format(type='np')
+   batch_size = train_ds['image'].shape[0]
+   shape = (batch_size, IMAGE_H, IMAGE_W, IMAGE_C)
+   new_train_ds = {
+       'image': train_ds['image'].astype(np.float32).reshape(shape) / 255.,
+       'label': train_ds['label']
+   }
+
+   test_ds = load_dataset('mnist', split='test')
+   test_ds.set_format(type='np')
+   batch_size = test_ds['image'].shape[0]
+   shape = (batch_size, IMAGE_H, IMAGE_W, IMAGE_C)
+   new_test_ds = {
+       'image': test_ds['image'].astype(np.float32).reshape(shape) / 255.,
+       'label': test_ds['label']
+   }
+   return new_train_ds, new_test_ds


def check_fp8(state, var_collect, input_shape, label_shape):
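
The reshape above matters because, once set_format(type='np') is applied, huggingface datasets exposes the MNIST images as a (N, 28, 28) uint8 array while the model expects NHWC input; a hedged sketch, assuming IMAGE_H = IMAGE_W = 28 and IMAGE_C = 1 as in a typical MNIST setup:

# Sketch of the expected array shapes; assumes the datasets package and an internet connection.
from datasets import load_dataset
import numpy as np

train_ds = load_dataset('mnist', split='train')
train_ds.set_format(type='np')

images = train_ds['image']                                         # uint8, shape (60000, 28, 28)
images = images.astype(np.float32).reshape(-1, 28, 28, 1) / 255.   # NHWC, scaled to [0, 1]
print(images.shape, images.dtype)                                  # (60000, 28, 28, 1) float32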
@@ -162,11 +157,9 @@ def check_fp8(state, var_collect, input_shape, label_shape):
def train_and_evaluate(args):
    """Execute model training and evaluation loop."""
-   os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
    print(args)

    if args.use_fp8:
-       assert gpu_has_fp8(), "GPU needs to support FP8."
        args.use_te = True

    train_ds, test_ds = get_datasets()
@@ -275,6 +268,8 @@ def mnist_parser(args):
class TestMNIST(unittest.TestCase):
    """MNIST unittests"""

+   gpu_has_fp8, reason = te.fp8.is_fp8_available()

    @classmethod
    def setUpClass(cls):
        """Run MNIST without Transformer Engine"""
@@ -299,7 +294,7 @@ class TestMNIST(unittest.TestCase):
        actual = train_and_evaluate(self.args)
        self.verify(actual)

-   @unittest.skipIf(not gpu_has_fp8(), reason='GPU capability is not enough to run FP8')
+   @unittest.skipIf(not gpu_has_fp8, reason)
    def test_te_fp8(self):
        """Test Transformer Engine with FP8"""
        self.args.use_fp8 = True
...
@@ -9,4 +9,5 @@ pytest -Wignore -v $TE_PATH/tests/jax
pip install -r $TE_PATH/examples/jax/mnist/requirements.txt
pip install -r $TE_PATH/examples/jax/encoder/requirements.txt
-pytest -Wignore -v $TE_PATH/examples/jax
+pytest -Wignore -v $TE_PATH/examples/jax --ignore=$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py
+pytest -Wignore -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py
@@ -13,13 +13,14 @@ from jax import lax
from jax import jit, value_and_grad
from flax import linen as nn

-from utils import assert_allclose, is_fp8_supported
+from utils import assert_allclose
from transformer_engine.common.recipe import Format
from transformer_engine.jax.cpp_extensions import dgated_gelu, gated_gelu
from transformer_engine.jax.cpp_extensions import dgated_gelu_cast_transpose, gated_gelu_fp8
from transformer_engine.jax.cpp_extensions import dequantize, quantize
from transformer_engine.jax.dot import fp8_dot
from transformer_engine.jax.fp8 import DType, FP8GemmPackage, FP8Helper, _format2dtypes
+from transformer_engine.jax.fp8 import is_fp8_available
from transformer_engine.jax.layernorm import layernorm
from transformer_engine.jax.mlp import fp8_ln_mlp
@@ -28,11 +29,12 @@ GEMM_CASES = [(256, 256, 512), (32, 32, 32), (16384, 1024, 2816), (16384, 2816,
FP8_COMPUTE_TYPE = [_format2dtypes(Format.E4M3), _format2dtypes(Format.HYBRID)]
LN_CASES = [(512, 1024)]
DTYPES = [jnp.bfloat16, jnp.float32]
+is_fp8_supported, reason = is_fp8_available()


class TestFP8Dot:

-   @pytest.mark.skipif(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
+   @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    def test_qdq(self):
        FP8_E4M3_MAX = 448
        x = jnp.asarray([[-1, 0.1], [2, 3]], jnp.float32)
@@ -74,7 +76,7 @@ class TestFP8Dot:
        value_n_grad_func_compiled = jit(value_n_grad_func).lower(a, b).compile()
        value_n_grad_func_compiled(a, b)

-   @pytest.mark.skipif(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
+   @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('compute_type', FP8_COMPUTE_TYPE)
    def test_compile_fp8(self, compute_type):
        key = jax.random.PRNGKey(0)
@@ -115,7 +117,7 @@ class TestFP8Dot:
        assert_allclose(primitive_out, ref_out)

-   @pytest.mark.skipif(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
+   @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('m,n,k', GEMM_CASES)
    @pytest.mark.parametrize('compute_type', FP8_COMPUTE_TYPE)
    def test_forward_fp8_randint(self, m, n, k, compute_type):
@@ -180,7 +182,7 @@ class TestFP8Dot:
        assert_allclose(primitive_a_grad, ref_a_grad)
        assert_allclose(primitive_b_grad, ref_b_grad, atol=1e-5)

-   @pytest.mark.skipif(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
+   @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('m,n,k', GEMM_CASES)
    @pytest.mark.parametrize('compute_type', FP8_COMPUTE_TYPE)
    def test_grad_fp8_randint(self, m, n, k, compute_type):
@@ -252,7 +254,7 @@ class TestFP8Dot:
        assert_allclose(primitive_a_grad, ref_a_grad)
        assert_allclose(primitive_b_grad, ref_b_grad)

-   @pytest.mark.skipif(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
+   @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('m,n,k', [(256, 256, 512), (16384, 1024, 2816), (16384, 2816, 1024),
                                       (16384, 1024, 1024)])
    def test_grad_fp8_mlp_randint(self, m, n, k):
@@ -464,7 +466,7 @@ class TestGatedGeLuFP8(TestGatedGeLu):
        return func(inputs, no_use, no_use)

-   @pytest.mark.skipif(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
+   @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('shape', [(32, 64), (64, 256)])
    def test_gated_gelu(self, random_inputs):
        self.amax = jnp.zeros(1, jnp.float32)
...
@@ -10,19 +10,21 @@ import jax.numpy as jnp
import numpy as np
from jax.experimental import maps

-from utils import assert_allclose, is_fp8_supported
+from utils import assert_allclose
from transformer_engine.common.recipe import DelayedScaling
from transformer_engine.common.recipe import Format as FP8Format
from transformer_engine.jax import fp8_autocast, get_delayed_scaling
-from transformer_engine.jax.fp8 import FP8Helper
+from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available
from transformer_engine.jax.sharding import infer_major_sharding_type
from transformer_engine.jax.sharding import MajorShardingType
from transformer_engine.jax.sharding import ShardingResource

+is_fp8_supported, reason = is_fp8_available()


class TestFP8Helper(unittest.TestCase):

-   @unittest.skipIf(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
+   @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_initialize(self):
        margin = 5.0
        fp8_format = FP8Format.E4M3
@@ -52,7 +54,7 @@ class TestFP8Helper(unittest.TestCase):
        FP8Helper.finalize()

-   @unittest.skipIf(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
+   @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_update_fp8_metas(self):
        FP8Helper.initialize(margin=3.0, amax_history_len=3)
@@ -113,7 +115,7 @@ class TestFP8Helper(unittest.TestCase):
        FP8Helper.finalize()

-   @unittest.skipIf(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
+   @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_generate_fp8_max_array(self):
        num_of_meta = FP8Helper.NUM_META_PER_GEMM * 2
@@ -131,7 +133,7 @@ class TestFP8Helper(unittest.TestCase):
        assert_allclose(get_ref(fp8_format), FP8Helper.generate_fp8_max_array(num_of_meta))
        FP8Helper.finalize()

-   @unittest.skipIf(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
+   @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_update_collections(self):
        original_val = 0.0
        updated_val = 10.0
@@ -163,7 +165,7 @@ class TestFP8Functions(unittest.TestCase):
        self.assertTrue(ref.amax_history_len == test.amax_history_len)
        self.assertTrue(ref.amax_compute_algo == test.amax_compute_algo)

-   @unittest.skipIf(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
+   @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_fp8_autocast(self):
        FP8Helper.finalize()    # Ensure the testing not affect by previous tests.
        self._check_defult_state()
@@ -188,7 +190,7 @@ class TestFP8Functions(unittest.TestCase):
        self._check_defult_state()

-   @unittest.skipIf(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
+   @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_fp8_autocast_with_sharding_resource(self):
        FP8Helper.finalize()    # Ensure the testing not affect by previous tests.
        self._check_defult_state()
...
@@ -11,11 +11,13 @@ import pytest
from transformer_engine.common.recipe import Format
from transformer_engine.jax.flax import TransformerLayer, TransformerLayerType
-from transformer_engine.jax.fp8 import FP8Helper
+from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available
-from utils import assert_allclose, is_fp8_supported
+from utils import assert_allclose
from utils import DecoderLayer as RefDecoderLayer
from utils import EncoderLayer as RefEncoderLayer

+is_fp8_supported, reason = is_fp8_available()


def loss_fn(diff_xs, no_diff_xs, params, others, model, rngs):
    output = model.apply({"params": params, **others}, *diff_xs, *no_diff_xs, rngs=rngs)
@@ -289,7 +291,7 @@ class TestEncoderLayer:
        FP8Helper.finalize()    # Ensure FP8 disabled.
        self.forward_runner(data_shape, dtype, attrs, rtol=1e-05, atol=2e-04)

-   @pytest.mark.skipif(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
+   @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('data_shape', DATA_SHAPE)
    @pytest.mark.parametrize('dtype', DTYPE)
    @pytest.mark.parametrize('fp8_format', FP8_FORMATS)
@@ -306,7 +308,7 @@ class TestEncoderLayer:
        FP8Helper.finalize()    # Ensure FP8 disabled.
        self.forward_backward_runner(data_shape, dtype, attrs, rtol=1e-05, atol=2e-04)

-   @pytest.mark.skipif(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
+   @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('data_shape', DATA_SHAPE)
    @pytest.mark.parametrize('dtype', DTYPE)
    @pytest.mark.parametrize('fp8_format', FP8_FORMATS)
@@ -516,7 +518,7 @@ class TestDecoderLayer:
        FP8Helper.finalize()    # Ensure FP8 disabled.
        self.forward_runner(data_shape, dtype, attrs, rtol=1e-05, atol=2e-04)

-   @pytest.mark.skipif(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
+   @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('data_shape', DATA_SHAPE)
    @pytest.mark.parametrize('dtype', DTYPE)
    @pytest.mark.parametrize('fp8_format', FP8_FORMATS)
@@ -533,7 +535,7 @@ class TestDecoderLayer:
        FP8Helper.finalize()    # Ensure FP8 disabled.
        self.forward_backward_runner(data_shape, dtype, attrs, rtol=1e-05, atol=2e-04)

-   @pytest.mark.skipif(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
+   @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('data_shape', DATA_SHAPE)
    @pytest.mark.parametrize('dtype', DTYPE)
    @pytest.mark.parametrize('fp8_format', FP8_FORMATS)
...
@@ -7,6 +7,7 @@ import numpy as np
import pytest
from jax.experimental import maps

+from utils import is_devices_enough
from transformer_engine.jax.flax import extend_logical_axis_rules
from transformer_engine.jax.sharding import get_dot_sharding_meta
from transformer_engine.jax.sharding import get_elementwise_sharding_meta
@@ -15,7 +16,6 @@ from transformer_engine.jax.sharding import global_shard_guard
from transformer_engine.jax.sharding import infer_major_sharding_type
from transformer_engine.jax.sharding import is_dp_enabled, is_tp_enabled
from transformer_engine.jax.sharding import ShardingMeta, ShardingResource, ShardingType
-from utils import is_devices_enough


def _get_sharding_resource(mesh_names, sharding_type):
...
@@ -9,7 +9,6 @@ from typing import Any, Callable, Tuple, Sequence, Union, Iterable, Optional
import jax
import jax.numpy as jnp
import numpy as np
-from cuda import cudart
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
from jax import lax, vmap
@@ -25,20 +24,6 @@ PrecisionLike = Union[None, str, lax.Precision, Tuple[str, str], Tuple[lax.Preci
Initializer = Callable[[PRNGKey, Shape, DType], Array]

-def is_fp8_supported():
-    """
-    Thus JAX doesn't have API to query capability
-    Use cuda-python for get the compute capability
-    """
-    cudaSuccess = cudart.cudaError_t.cudaSuccess
-    ret, gpu_id = cudart.cudaGetDevice()
-    assert ret == cudaSuccess
-    flag = cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMajor
-    ret, sm_major = cudart.cudaDeviceGetAttribute(flag, gpu_id)
-    assert ret == cudaSuccess
-    return sm_major >= 9

def is_devices_enough(required):
    return len(jax.devices()) >= required
...
@@ -6,6 +6,7 @@ pybind11_add_module(
    transformer_engine_jax
    ${CMAKE_CURRENT_SOURCE_DIR}/csrc/extensions.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/csrc/modules.cpp
+   ${CMAKE_CURRENT_SOURCE_DIR}/csrc/utils.cpp
)

target_link_libraries(transformer_engine_jax PRIVATE CUDA::cudart CUDA::cublas CUDA::cublasLt transformer_engine)
@@ -6,6 +6,8 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

+#include <cublasLt.h>

#include "common/include/transformer_engine/fused_attn.h"
#include "common/include/transformer_engine/transformer_engine.h"
#include "jax/csrc/modules.h"
@@ -58,6 +60,9 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
    m.def("pack_gemm_descriptor", &PackCustomCallGemmDescriptor);
    m.def("pack_norm_descriptor", &PackCustomCallNormDescriptor);
    m.def("pack_softmax_descriptor", &PackCustomCallSoftmaxDescriptor);
+   m.def("get_cublasLt_version", &cublasLtGetVersion);
+   m.def("get_cuda_version", &GetCudaRuntimeVersion);
+   m.def("get_device_compute_capability", &GetDeviceComputeCapability);
    m.def("pack_fused_attn_descriptor", &PackCustomCallFusedAttnDescriptor);
    m.def("is_fused_attn_kernel_available", &IsFusedAttnKernelAvailable);
...
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda_runtime_api.h>
#include <cassert>
#include "utils.h"
namespace transformer_engine {
namespace jax {
int GetCudaRuntimeVersion() {
int ver = 0;
NVTE_CHECK_CUDA(cudaRuntimeGetVersion(&ver));
return ver;
}
int GetDeviceComputeCapability(int gpu_id) {
int max_num_gpu = 0;
NVTE_CHECK_CUDA(cudaGetDeviceCount(&max_num_gpu));
assert(gpu_id < max_num_gpu);
int major = 0;
NVTE_CHECK_CUDA(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, gpu_id));
int minor = 0;
NVTE_CHECK_CUDA(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, gpu_id));
int gpu_arch = major * 10 + minor;
return gpu_arch;
}
} // namespace jax
} // namespace transformer_engine
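
These helpers are exported through the pybind11 module shown earlier, so the Python side can query toolkit versions and compute capability without cuda-python; a hedged sketch of exercising the bindings (device 0 and the tex alias are assumptions, and the values in the comments are examples only):

# Sketch of querying the new bindings; assumes the extension is built and a GPU is visible.
import transformer_engine_jax as tex

print(tex.get_cuda_version())                  # e.g. 12010 for CUDA 12.1
print(tex.get_cublasLt_version())              # e.g. 120103
print(tex.get_device_compute_capability(0))    # e.g. 90 on Hopper, 89 on Ada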
@@ -18,6 +18,9 @@
namespace transformer_engine {
namespace jax {

+int GetCudaRuntimeVersion();
+int GetDeviceComputeCapability(int gpu_id);

class cublasLtMetaManager {
 public:
    static cublasLtMetaManager &Instance() {
...
@@ -13,13 +13,50 @@ import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict

from transformer_engine_jax import DType
+from transformer_engine_jax import get_cublasLt_version
+from transformer_engine_jax import get_cuda_version, get_device_compute_capability
from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.jax.sharding import global_shard_guard
from transformer_engine.jax.sharding import ShardingResource

+_is_fp8_available = None
+_reason_for_no_fp8 = ""

Collection = Union[Dict, FrozenDict]


+def _check_fp8_support(gpu_id) -> Tuple[bool, str]:
+    """Return if fp8 support is available"""
+    gpu_arch = get_device_compute_capability(gpu_id)
+    if gpu_arch >= 90:    # hopper and above
+        return True, ""
+    if gpu_arch < 89:    # pre-ada
+        return False, "Device compute capability 8.9 or higher required for FP8 execution."
+    if get_cublasLt_version() < 120103:
+        return False, "CublasLt version 12.1.3.x or higher required for FP8 execution on Ada."
+    if get_cuda_version() < 12010:
+        return False, "Cuda version 12.1 or higher required for FP8 execution on Ada."
+    return True, ""
+
+
+def is_fp8_available(gpu_id=None) -> Tuple[bool, str]:
+    """Return if fp8 support is available"""
+    if gpu_id is not None:
+        return _check_fp8_support(gpu_id)
+
+    global _is_fp8_available, _reason_for_no_fp8
+    if _is_fp8_available is None:
+        _is_fp8_available = True
+        devices = jax.local_devices()
+        for gpu in devices:
+            ret, msg = _check_fp8_support(gpu.id)
+            if ret is False:
+                _is_fp8_available = ret
+                _reason_for_no_fp8 = msg
+                break
+    return _is_fp8_available, _reason_for_no_fp8


def _format2dtypes(format_: Format):
    if format_ == Format.E4M3:
        return DType.kFloat8E4M3, DType.kFloat8E4M3
@@ -332,6 +369,9 @@ def fp8_autocast(enabled: bool = False,
    try:
        with global_shard_guard(sharding_resource):
            if enabled:
+               fp8_available, reason_for_no_fp8 = is_fp8_available()
+               assert fp8_available, reason_for_no_fp8
+
                amax_compute_algo = AmaxComputeAlgo.MOST_RECENT
                if fp8_recipe.amax_compute_algo == 'max':
                    amax_compute_algo = AmaxComputeAlgo.MAX
...
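
With the availability check wired into fp8_autocast, enabling FP8 on unsupported hardware now fails with the specific reason string rather than a bare capability assertion; a hedged usage sketch (module paths follow the imports shown in the diff, the fallback logic is illustrative only):

# Sketch of guarding FP8 usage with the new helper.
import transformer_engine.jax as te
from transformer_engine.jax.fp8 import is_fp8_available

fp8_available, reason = is_fp8_available()
if not fp8_available:
    print(f"Falling back to higher precision: {reason}")

with te.fp8_autocast(enabled=fp8_available):
    ...  # build and run the model; enabled=True on unsupported hardware would raise AssertionError(reason)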