Unverified commit 496b8fdd, authored by Jeng Bai-Cheng, committed by GitHub

[JAX] add multiprocessing example and improve debugging message (#198)



* add mp example
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* refactor
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* update doc-string
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* better FP8 checker
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* update readme
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* replace te.* with te.flax.* to remove the deprecation warning
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* remove unused os.environ
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* remove unused code
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* fix typo
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* Update examples/jax/encoder/README.md
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Jeng Bai-Cheng <jeng1220@users.noreply.github.com>
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* Update examples/jax/encoder/README.md
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Jeng Bai-Cheng <jeng1220@users.noreply.github.com>
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* Update examples/jax/encoder/README.md
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Jeng Bai-Cheng <jeng1220@users.noreply.github.com>
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* Update examples/jax/encoder/README.md
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Jeng Bai-Cheng <jeng1220@users.noreply.github.com>
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* Update examples/jax/encoder/test_multiprocessing_encoder.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Jeng Bai-Cheng <jeng1220@users.noreply.github.com>
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* remove cuda-python
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* adjust readme
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* Update examples/jax/encoder/README.md
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* fix cpp lint

Fix the issue of "Could not find a newline character at the end of the file."
Signed-off-by: Jeng Bai-Cheng <jeng1220@users.noreply.github.com>
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* fix AssertionError: 1 GPU per process
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* replace tfds with datasets

The Flax application crashes if it uses TensorFlow Datasets (tfds) in the NVIDIA JAX container.
tfds is very useful for downloading well-known datasets (e.g., MNIST, GLUE) and is commonly used by the TF/JAX community.
However, it appears to be incompatible with NVIDIA TensorFlow in the NVIDIA JAX container and somehow affects JAX.
It triggers random errors at JAX initialization depending on the versions involved, making CI unstable.
Thus, this commit replaces tfds with huggingface `datasets` to download the needed datasets.
See "nvbugs 4039266" for more details.
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>
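The replacement pattern, as applied in this commit's `get_datasets()`, can be sketched as follows (the lazy import keeps the snippet self-contained; calling the function downloads data):

```python
def get_cola_numpy(split):
    """Fetch a GLUE/CoLA split via huggingface `datasets` as NumPy arrays,
    mirroring this commit's replacement for tfds.load()."""
    from datasets import load_dataset  # requires the `datasets` package
    ds = load_dataset('glue', 'cola', split=split)
    ds.set_format(type='np')  # return NumPy arrays instead of Python lists
    return ds

# usage (downloads on first call): train_ds = get_cola_numpy('train')
```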

* fix input sharding

Unlike SPMD mode, in multiprocessing mode the input tensor must be sharded manually.
Using DP=4, TP=2 as an example, the device mesh looks like:

mesh.device_ids = [[0, 1],
                   [2, 3],
                   [4, 5],
                   [6, 7]]

Assume that the process ID is mapped to the GPU ID.
Processes 0 and 1 are grouped for model parallelism,
processes 2 and 3 are grouped together too, and so on.

Processes 0 and 1 need to share the same micro-batch in the training step,
while processes 0, 2, 4, and 6 get different micro-batches.

Thus, `shard_array_wrapper` partitions the inputs into 4 parts (and sets up
the arguments needed for jax.make_array_from_single_device_arrays).
Processes 0 and 1 take the first quarter,
processes 2 and 3 take the second quarter, and so on.
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>
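The quarter-taking logic described above can be sketched in plain Python (`local_batch_slice` is a hypothetical helper, not the commit's `shard_array_wrapper`; the real code feeds the resulting shard to `jax.make_array_from_single_device_arrays`):

```python
# Hypothetical helper illustrating the slicing described above: TP peers
# (processes 0/1, 2/3, ...) share a data-parallel index, so they receive
# the same micro-batch slice of the global batch.
def local_batch_slice(process_id, num_processes, global_batch_size, tp_size=2):
    dp_size = num_processes // tp_size          # number of data-parallel groups
    assert global_batch_size % dp_size == 0, "batch must divide across DP groups"
    per_group = global_batch_size // dp_size    # micro-batch rows per DP group
    dp_index = process_id // tp_size            # TP peers map to the same group
    return slice(dp_index * per_group, (dp_index + 1) * per_group)

# With 8 processes, TP=2 (so DP=4) and a global batch of 64:
# processes 0 and 1 take rows 0..15, processes 2 and 3 take rows 16..31, etc.
```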

* refactor UT for multiprocess example

Use Python `multiprocessing` to test the multiprocessing example
if the system has multiple GPUs, with 1 GPU per process.

Because `jax.distributed.initialize` must be called before any other JAX or Flax API,
GPU info cannot be queried by calling jax.local_devices() in TestEncoder.
Thus, `unittest_query_gpu()` forks another process to query the number of GPUs and
the FP8 capability.
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>
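The fork-to-query trick can be sketched as follows (the child body is a stand-in: in the real `unittest_query_gpu()` it would import JAX and report `len(jax.local_devices())` plus the FP8 capability; a placeholder result keeps the sketch runnable without a GPU):

```python
import multiprocessing as mp

def _query(queue):
    # The real child would `import jax` here and report the device count
    # and FP8 capability; a fixed placeholder keeps this self-contained.
    queue.put({"num_gpu": 1, "has_fp8": False})

def query_gpu_info():
    # Fork a child so the parent process never initializes JAX before
    # jax.distributed.initialize is called.
    ctx = mp.get_context("fork")
    queue = ctx.Queue()
    child = ctx.Process(target=_query, args=(queue,))
    child.start()
    info = queue.get()
    child.join()
    return info
```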

* remove unused arg `--num-gpu`

JAX doesn't have an API to set the number of GPUs used in SPMD mode.
The only way for now is to use `CUDA_VISIBLE_DEVICES`.
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>
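A minimal sketch of that workaround; the variable must be set before JAX is imported, since the backend enumerates GPUs at initialization:

```python
import os

# Expose only the first four GPUs; JAX will see exactly these devices.
# This must happen before `import jax` (or any other CUDA-touching import).
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"

# import jax  # deferred until after the environment variable is set
```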

* fix typo
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* fix ut
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* simplify the mask setting
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* increase batch-size for multigpu example

A batch size of 64 is too small to be partitioned across 8xH100.
With a batch size of 64, the GEMM shape is 256x8192x8 per GPU.
The 8 is too small for the FP8 GEMM kernel, and
cuBLASLt throws "Failed to query heuristics".
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>
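The arithmetic behind the bump, as a sketch (treating the offending dimension as `batch_size // num_gpu` is an inference from the message above, not verified against the kernel):

```python
def per_gpu_gemm_dim(batch_size, num_gpu=8):
    # the GEMM dimension that shrinks as the batch is split across GPUs
    return batch_size // num_gpu

# A batch of 64 on 8 GPUs leaves a dimension of 8, which the FP8 GEMM
# path rejects ("Failed to query heuristics"); a batch of 128 doubles it.
```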

* fix downloading mnist error

Downloading MNIST via huggingface `datasets` requires Pillow.
Otherwise, it throws "An error occurred while generating the dataset".
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

---------
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>
Signed-off-by: Jeng Bai-Cheng <jeng1220@users.noreply.github.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 441fa968
......@@ -35,7 +35,14 @@ python test_single_gpu_encoder.py --use-fp8
6. Fill in `params_pspec` and `encoder.init` to pjit to get a compiled function, `pjit_encoder_init `, and use it to initialize the model, so JAX now can know how to do the sharding.
7. The `train_step` and `eval_step` also needs to be compiled by pjit. Thus, every input and output argument has to be set up `PartitionSpec` if the argument contains a tensor. For instance, the `input_pspec` is `PartitionSpec('data', None)` because the input shape is (batch size, sequence length). Then, the rest of the workflow is similar to the previous example.
7. The `train_step` and `eval_step` also need to be compiled by pjit. Thus, every input and output argument has to be set up `PartitionSpec` if the argument contains a tensor. For instance, the `input_pspec` is `PartitionSpec('data', None)` because the input shape is (batch size, sequence length). Then, the rest of the workflow is similar to the previous example.
8. Use `CUDA_VISIBLE_DEVICES` to control the number of GPUs used. For example, if the system has 8 GPUs but only 4 GPUs need to be used, then:
```sh
export CUDA_VISIBLE_DEVICES=0,1,2,3
python test_multigpu_encoder.py
```
Please refer to [CUDA Environment Variables](https://docs.nvidia.com/cuda/cuda-c-programming-guide/#cuda-environment-variables) for more details.
### Run ###
......@@ -60,6 +67,10 @@ python test_multigpu_encoder.py --use-fp8
import os
os.environ['XLA_FLAGS'] = "--xla_dump_hlo_as_proto --xla_dump_hlo_as_text --xla_dump_hlo_as_html --xla_dump_to=<path to store XLA HLO>"
```
5. If the model parallelism example is run in the container, it is recommended to add `--ipc=host` in launch arguments. Otherwise, it might trigger UCX errors.
```sh
docker run --gpus=all --ipc=host ...
```
### Run ###
......@@ -67,3 +78,62 @@ python test_multigpu_encoder.py --use-fp8
python test_model_parallel_encoder.py
python test_model_parallel_encoder.py --use-fp8
```
## Multiple Processes with Model Parallelism ##
1. This example inherits the previous model parallelism example, but uses multiprocessing instead of single-program multiple-data (SPMD). It uses 1 GPU per process.
2. The benefit of multiprocessing is the ability to set hardware affinity for GPUs, such as NUMA binding, which may help improve performance and stability. Please refer to [Best Practices When Benchmarking CUDA Applications](https://www.nvidia.com/en-us/on-demand/session/gtcsiliconvalley2019-s9956/) for more details.
3. The quick way to check the system topology is to use `nvidia-smi`, for example:
```sh
$ nvidia-smi topo -mp
CPU Affinity NUMA Affinity
GPU0 48-63,176-191 3
GPU1 48-63,176-191 3
GPU2 16-31,144-159 1
GPU3 16-31,144-159 1
GPU4 112-127,240-255 7
GPU5 112-127,240-255 7
GPU6 80-95,208-223 5
GPU7 80-95,208-223 5
```
4. It is recommended to set the environment variable `CUDA_DEVICE_ORDER` to `PCI_BUS_ID` before running the example with the affinity setting, to ensure that the device order is aligned between CUDA and `nvidia-smi`. Please refer to [CUDA Environment Variables](https://docs.nvidia.com/cuda/cuda-c-programming-guide/#cuda-environment-variables) for more details.
5. `jax.distributed.initialize` must be called before any other JAX or Flax API, otherwise `jax.local_devices` will be incorrect. `jax.distributed.shutdown` should be the last API call.
6. Unlike SPMD, the input tensor must be sharded manually and be wrapped by `jax.make_array_from_single_device_arrays`. Otherwise, the sharding will be incorrect. Using DP=4, TP=2 as an example, the device mesh looks like:
```python
mesh.device_ids = [[0, 1],
[2, 3],
[4, 5],
[6, 7]]
```
Assume that the process ID is mapped to the GPU ID. Processes 0 and 1 are grouped for model parallelism, processes 2 and 3 are grouped together too, and so on. Thus, processes 0 and 1 need to share the same micro-batch in the training step, while processes 0, 2, 4, and 6 get different micro-batches.
### Run ###
If the system has 8 GPUs, the basic commands are:
```bash
python test_multiprocessing_encoder.py --num-process 8 --process-id 0 &
python test_multiprocessing_encoder.py --num-process 8 --process-id 1 &
python test_multiprocessing_encoder.py --num-process 8 --process-id 2 &
python test_multiprocessing_encoder.py --num-process 8 --process-id 3 &
python test_multiprocessing_encoder.py --num-process 8 --process-id 4 &
python test_multiprocessing_encoder.py --num-process 8 --process-id 5 &
python test_multiprocessing_encoder.py --num-process 8 --process-id 6 &
python test_multiprocessing_encoder.py --num-process 8 --process-id 7 &
```
The correct setting for hardware affinity is system dependent. Taking the above system topology as an example, the command can be:
```bash
numactl --cpunodebind=48 --membind=3 python test_multiprocessing_encoder.py --num-process 8 --process-id 0 &
numactl --cpunodebind=49 --membind=3 python test_multiprocessing_encoder.py --num-process 8 --process-id 1 &
numactl --cpunodebind=16 --membind=1 python test_multiprocessing_encoder.py --num-process 8 --process-id 2 &
numactl --cpunodebind=17 --membind=1 python test_multiprocessing_encoder.py --num-process 8 --process-id 3 &
numactl --cpunodebind=112 --membind=7 python test_multiprocessing_encoder.py --num-process 8 --process-id 4 &
numactl --cpunodebind=113 --membind=7 python test_multiprocessing_encoder.py --num-process 8 --process-id 5 &
numactl --cpunodebind=80 --membind=5 python test_multiprocessing_encoder.py --num-process 8 --process-id 6 &
numactl --cpunodebind=81 --membind=5 python test_multiprocessing_encoder.py --num-process 8 --process-id 7 &
```
\ No newline at end of file
datasets
flax
nltk
optax
tensorflow-datasets
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
""" Encoder training on multi-GPU with tesnor parallelism"""
"""Encoder training on multi-GPU with tesnor parallelism"""
import argparse
import unittest
from functools import partial
......@@ -11,10 +11,10 @@ import jax.numpy as jnp
import nltk
import numpy as np
import optax
import tensorflow_datasets as tfds
from cuda import cudart
from datasets import load_dataset
from flax import linen as nn
from flax.core.frozen_dict import FrozenDict
from flax.linen import partitioning as nn_partitioning
from flax.training import train_state
from jax.experimental import mesh_utils
from jax.experimental.pjit import pjit
......@@ -31,26 +31,6 @@ DROPOUT_KEY = 'dropout'
INPUT_KEY = 'input_rng'
def check_num_gpu(desired_num_gpu):
"""Check if the number of GPUs are correct."""
actual_num_gpu = len(jax.local_devices())
assert actual_num_gpu == desired_num_gpu, f"Number of GPUs is mismatch. " \
f"{desired_num_gpu} GPUs are assigned, but the actual number of GPUs is {actual_num_gpu}"
def gpu_has_fp8():
"""Check if the GPU has FP8."""
cudaSuccess = cudart.cudaError_t.cudaSuccess
ret, gpu_id = cudart.cudaGetDevice()
assert ret == cudaSuccess
flag = cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMajor
_, major = cudart.cudaDeviceGetAttribute(flag, gpu_id)
flag = cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMinor
_, minor = cudart.cudaDeviceGetAttribute(flag, gpu_id)
sm_arch = major * 10 + minor
return sm_arch >= 89
class Net(nn.Module):
"""NLP Encoder"""
num_embed: int
......@@ -66,7 +46,7 @@ class Net(nn.Module):
hidden_dropout=0.1,
attention_dropout=0.1,
dropout_rng_name=DROPOUT_KEY,
layer_type=te.TransformerLayerType.ENCODER,
layer_type=te.flax.TransformerLayerType.ENCODER,
enable_relative_embedding=False,
dtype=jnp.bfloat16)
x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
......@@ -177,12 +157,11 @@ def data_preprocess(dataset, vocab, word_id, max_seq_len):
nltk.download('punkt')
dataset_size = len(dataset['sentence'])
output = np.zeros((dataset_size, max_seq_len), dtype=np.int32)
mask_3d = np.empty((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8)
mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8)
for j, sentence in enumerate(dataset['sentence']):
tokens = nltk.word_tokenize(sentence.decode("utf-8"))
tokens = nltk.word_tokenize(sentence)
tensor = output[j]
mask_1d = np.zeros((1, max_seq_len), dtype=np.uint8)
for i, word in enumerate(tokens):
if i >= max_seq_len:
......@@ -195,26 +174,31 @@ def data_preprocess(dataset, vocab, word_id, max_seq_len):
else:
tensor[i] = vocab[word]
mask_1d[0, i] = 1
seq_len = len(tokens)
if seq_len > max_seq_len:
seq_len = max_seq_len
mask_2d = mask_3d[j]
np.dot(mask_1d.T, mask_1d, out=mask_2d)
np.subtract(1, mask_2d, out=mask_2d)
mask_2d[:seq_len, :seq_len] = 0
dataset['sentence'] = output
dataset['label'] = dataset['label'].astype(np.float32)
dataset['mask'] = mask_3d.reshape((dataset_size, 1, max_seq_len, max_seq_len))
return dataset, vocab, word_id
new_dataset = {
'sentence': output,
'label': dataset['label'].astype(np.float32),
'mask': mask_3d.reshape((dataset_size, 1, max_seq_len, max_seq_len))
}
return new_dataset, vocab, word_id
def get_datasets(max_seq_len):
"""Load GLUE train and test datasets into memory."""
vocab = {}
word_id = 0
dataset = 'glue/cola'
train_ds = tfds.as_numpy(tfds.load(dataset, split='train', batch_size=-1))
train_ds = load_dataset('glue', 'cola', split='train')
train_ds.set_format(type='np')
train_ds, vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len)
test_ds = tfds.as_numpy(tfds.load(dataset, split='validation', batch_size=-1))
test_ds = load_dataset('glue', 'cola', split='validation')
test_ds.set_format(type='np')
test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len)
return train_ds, test_ds, word_id
......@@ -238,7 +222,7 @@ def get_params_pspec(sharding_rules, abs_var_collect):
return jax.sharding.PartitionSpec(*partitions)
params_axes = abs_var_collect.get(PARAMS_AXES_KEY, {})
params_axes_pspec = jax.tree_map(to_device_axis, nn.partitioning.get_axis_names(params_axes))
params_axes_pspec = jax.tree_map(to_device_axis, nn_partitioning.get_axis_names(params_axes))
params_pspec = jax.tree_map(lambda x: jax.sharding.PartitionSpec(), abs_var_collect[PARAMS_KEY])
params_pspec = FrozenDict({**params_pspec, **params_axes_pspec})
return params_pspec
......@@ -257,14 +241,12 @@ def get_state_pspec(state, params_pspec):
def train_and_evaluate(args):
"""Execute model training and evaluation loop."""
print(args)
check_num_gpu(args.num_gpu)
if args.use_fp8:
assert gpu_has_fp8(), "GPU needs to support FP8."
train_ds, test_ds, num_embed = get_datasets(args.max_seq_len)
num_gpu = jax.local_device_count()
num_gpu_tp = 2
if args.num_gpu % num_gpu_tp == 0:
num_gpu_dp = args.num_gpu // num_gpu_tp
if num_gpu % num_gpu_tp == 0:
num_gpu_dp = num_gpu // num_gpu_tp
else:
num_gpu_dp = 1
num_gpu_tp = 1
......@@ -287,14 +269,13 @@ def train_and_evaluate(args):
with te.fp8_autocast(args.use_fp8,
sharding_resource=te.ShardingResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS)):
train_ds, test_ds, num_embed = get_datasets(args.max_seq_len)
encoder = Net(num_embed)
inputs = jnp.zeros(input_shape, dtype=jnp.int32)
masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks)
customized_rules = ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS))
sharding_rules = te.extend_logical_axis_rules(tuple()) + customized_rules
sharding_rules = te.flax.extend_logical_axis_rules(tuple()) + customized_rules
params_pspec = get_params_pspec(sharding_rules, abs_var_collect)
inputs_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None)
masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None)
......@@ -356,13 +337,6 @@ def train_and_evaluate(args):
def encoder_parser(args):
"""Training settings."""
parser = argparse.ArgumentParser(description="JAX Encoder Example")
parser.add_argument(
"--num-gpu",
type=int,
default=8,
metavar="N",
help="number of GPUs (default: 8)",
)
parser.add_argument(
"--batch-size",
type=int,
......@@ -416,20 +390,19 @@ def encoder_parser(args):
class TestEncoder(unittest.TestCase):
"""Encoder unittests"""
gpu_has_fp8, reason = te.fp8.is_fp8_available()
@classmethod
def setUpClass(cls):
"""Run 3 epochs for testing"""
num_gpu = len(jax.local_devices())
if num_gpu % 2 != 0:
num_gpu = 1
cls.args = encoder_parser(["--epochs", "3", "--num-gpu", str(num_gpu)])
cls.args = encoder_parser(["--epochs", "3"])
def test_te_bf16(self):
"""Test Transformer Engine with BF16"""
actual = train_and_evaluate(self.args)
assert actual[0] < 0.45 and actual[1] > 0.79
@unittest.skipIf(not gpu_has_fp8(), reason='GPU capability is not enough to run FP8')
@unittest.skipIf(not gpu_has_fp8, reason)
def test_te_fp8(self):
"""Test Transformer Engine with FP8"""
self.args.use_fp8 = True
......
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
""" Encoder training on multi-GPU with data parallelism"""
"""Encoder training on multi-GPU with data parallelism"""
import argparse
import unittest
from functools import partial
......@@ -11,10 +11,10 @@ import jax.numpy as jnp
import nltk
import numpy as np
import optax
import tensorflow_datasets as tfds
from cuda import cudart
from datasets import load_dataset
from flax import linen as nn
from flax.core.frozen_dict import FrozenDict
from flax.linen import partitioning as nn_partitioning
from flax.training import train_state
from jax.experimental import mesh_utils
from jax.experimental.pjit import pjit
......@@ -28,26 +28,6 @@ DROPOUT_KEY = 'dropout'
INPUT_KEY = 'input_rng'
def check_num_gpu(desired_num_gpu):
"""Check if the number of GPUs are correct."""
actual_num_gpu = len(jax.local_devices())
assert actual_num_gpu == desired_num_gpu, f"Number of GPUs is mismatch. " \
f"{desired_num_gpu} GPUs are assigned, but the actual number of GPUs is {actual_num_gpu}"
def gpu_has_fp8():
"""Check if the GPU has FP8."""
cudaSuccess = cudart.cudaError_t.cudaSuccess
ret, gpu_id = cudart.cudaGetDevice()
assert ret == cudaSuccess
flag = cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMajor
_, major = cudart.cudaDeviceGetAttribute(flag, gpu_id)
flag = cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMinor
_, minor = cudart.cudaDeviceGetAttribute(flag, gpu_id)
sm_arch = major * 10 + minor
return sm_arch >= 89
class Net(nn.Module):
"""NLP Encoder"""
num_embed: int
......@@ -63,7 +43,7 @@ class Net(nn.Module):
hidden_dropout=0.1,
attention_dropout=0.1,
dropout_rng_name=DROPOUT_KEY,
layer_type=te.TransformerLayerType.ENCODER,
layer_type=te.flax.TransformerLayerType.ENCODER,
enable_relative_embedding=False,
dtype=jnp.bfloat16)
x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
......@@ -168,12 +148,11 @@ def data_preprocess(dataset, vocab, word_id, max_seq_len):
nltk.download('punkt')
dataset_size = len(dataset['sentence'])
output = np.zeros((dataset_size, max_seq_len), dtype=np.int32)
mask_3d = np.empty((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8)
mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8)
for j, sentence in enumerate(dataset['sentence']):
tokens = nltk.word_tokenize(sentence.decode("utf-8"))
tokens = nltk.word_tokenize(sentence)
tensor = output[j]
mask_1d = np.zeros((1, max_seq_len), dtype=np.uint8)
for i, word in enumerate(tokens):
if i >= max_seq_len:
......@@ -186,26 +165,31 @@ def data_preprocess(dataset, vocab, word_id, max_seq_len):
else:
tensor[i] = vocab[word]
mask_1d[0, i] = 1
seq_len = len(tokens)
if seq_len > max_seq_len:
seq_len = max_seq_len
mask_2d = mask_3d[j]
np.dot(mask_1d.T, mask_1d, out=mask_2d)
np.subtract(1, mask_2d, out=mask_2d)
mask_2d[:seq_len, :seq_len] = 0
dataset['sentence'] = output
dataset['label'] = dataset['label'].astype(np.float32)
dataset['mask'] = mask_3d.reshape((dataset_size, 1, max_seq_len, max_seq_len))
return dataset, vocab, word_id
new_dataset = {
'sentence': output,
'label': dataset['label'].astype(np.float32),
'mask': mask_3d.reshape((dataset_size, 1, max_seq_len, max_seq_len))
}
return new_dataset, vocab, word_id
def get_datasets(max_seq_len):
"""Load GLUE train and test datasets into memory."""
vocab = {}
word_id = 0
dataset = 'glue/cola'
train_ds = tfds.as_numpy(tfds.load(dataset, split='train', batch_size=-1))
train_ds = load_dataset('glue', 'cola', split='train')
train_ds.set_format(type='np')
train_ds, vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len)
test_ds = tfds.as_numpy(tfds.load(dataset, split='validation', batch_size=-1))
test_ds = load_dataset('glue', 'cola', split='validation')
test_ds.set_format(type='np')
test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len)
return train_ds, test_ds, word_id
......@@ -229,7 +213,7 @@ def get_params_pspec(sharding_rules, abs_var_collect):
return jax.sharding.PartitionSpec(*partitions)
params_axes = abs_var_collect.get(PARAMS_AXES_KEY, {})
params_axes_pspec = jax.tree_map(to_device_axis, nn.partitioning.get_axis_names(params_axes))
params_axes_pspec = jax.tree_map(to_device_axis, nn_partitioning.get_axis_names(params_axes))
params_pspec = jax.tree_map(lambda x: jax.sharding.PartitionSpec(), abs_var_collect[PARAMS_KEY])
params_pspec = FrozenDict({**params_pspec, **params_axes_pspec})
return params_pspec
......@@ -248,15 +232,14 @@ def get_state_pspec(state, params_pspec):
def train_and_evaluate(args):
"""Execute model training and evaluation loop."""
print(args)
check_num_gpu(args.num_gpu)
assert args.batch_size % args.num_gpu == 0, f"Batch size needs to be multiple of {args.num_gpu}"
assert args.test_batch_size % args.num_gpu == 0, \
f"Test batch size needs to be multiple of {args.num_gpu}"
train_ds, test_ds, num_embed = get_datasets(args.max_seq_len)
if args.use_fp8:
assert gpu_has_fp8(), "GPU needs to support FP8."
num_gpu = jax.local_device_count()
assert args.batch_size % num_gpu == 0, f"Batch size needs to be multiple of {num_gpu}"
assert args.test_batch_size % num_gpu == 0, \
f"Test batch size needs to be multiple of {num_gpu}"
device_mesh = mesh_utils.create_device_mesh((args.num_gpu,))
device_mesh = mesh_utils.create_device_mesh((num_gpu,))
with jax.sharding.Mesh(devices=device_mesh, axis_names=(DEVICE_DP_AXIS,)):
rng = jax.random.PRNGKey(args.seed)
......@@ -269,13 +252,12 @@ def train_and_evaluate(args):
label_shape = [args.batch_size]
with te.fp8_autocast(args.use_fp8, sharding_resource=te.ShardingResource(DEVICE_DP_AXIS)):
train_ds, test_ds, num_embed = get_datasets(args.max_seq_len)
encoder = Net(num_embed)
inputs = jnp.zeros(input_shape, dtype=jnp.int32)
masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks)
sharding_rules = te.extend_logical_axis_rules(tuple())
sharding_rules = te.flax.extend_logical_axis_rules(tuple())
params_pspec = get_params_pspec(sharding_rules, abs_var_collect)
inputs_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None)
masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None)
......@@ -337,26 +319,19 @@ def train_and_evaluate(args):
def encoder_parser(args):
"""Training settings."""
parser = argparse.ArgumentParser(description="JAX Encoder Example")
parser.add_argument(
"--num-gpu",
type=int,
default=8,
metavar="N",
help="number of GPUs (default: 8)",
)
parser.add_argument(
"--batch-size",
type=int,
default=64,
default=128,
metavar="N",
help="input batch size for training (default: 64)",
help="input batch size for training (default: 128)",
)
parser.add_argument(
"--test-batch-size",
type=int,
default=64,
default=128,
metavar="N",
help="input batch size for testing (default: 64)",
help="input batch size for testing (default: 128)",
)
parser.add_argument(
"--max-seq-len",
......@@ -397,25 +372,24 @@ def encoder_parser(args):
class TestEncoder(unittest.TestCase):
"""Encoder unittests"""
gpu_has_fp8, reason = te.fp8.is_fp8_available()
@classmethod
def setUpClass(cls):
"""Run 3 epochs for testing"""
num_gpu = len(jax.local_devices())
if num_gpu % 2 != 0:
num_gpu = 1
cls.args = encoder_parser(["--epochs", "3", "--num-gpu", str(num_gpu)])
cls.args = encoder_parser(["--epochs", "3"])
def test_te_bf16(self):
"""Test Transformer Engine with BF16"""
actual = train_and_evaluate(self.args)
assert actual[0] < 0.45 and actual[1] > 0.79
assert actual[0] < 0.49 and actual[1] > 0.76
@unittest.skipIf(not gpu_has_fp8(), reason='GPU capability is not enough to run FP8')
@unittest.skipIf(not gpu_has_fp8, reason)
def test_te_fp8(self):
"""Test Transformer Engine with FP8"""
self.args.use_fp8 = True
actual = train_and_evaluate(self.args)
assert actual[0] < 0.45 and actual[1] > 0.79
assert actual[0] < 0.49 and actual[1] > 0.76
if __name__ == "__main__":
......
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
""" Encoder training on single GPU"""
"""Encoder training on single GPU"""
import argparse
import os
import unittest
from functools import partial
......@@ -12,8 +11,7 @@ import jax.numpy as jnp
import nltk
import numpy as np
import optax
import tensorflow_datasets as tfds
from cuda import cudart
from datasets import load_dataset
from flax import linen as nn
from flax.core.frozen_dict import FrozenDict
from flax.training import train_state
......@@ -25,19 +23,6 @@ DROPOUT_KEY = 'dropout'
INPUT_KEY = 'input_rng'
def gpu_has_fp8():
"""Check if the GPU has FP8."""
cudaSuccess = cudart.cudaError_t.cudaSuccess
ret, gpu_id = cudart.cudaGetDevice()
assert ret == cudaSuccess
flag = cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMajor
_, major = cudart.cudaDeviceGetAttribute(flag, gpu_id)
flag = cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMinor
_, minor = cudart.cudaDeviceGetAttribute(flag, gpu_id)
sm_arch = major * 10 + minor
return sm_arch >= 89
class Net(nn.Module):
"""NLP Encoder"""
num_embed: int
......@@ -53,7 +38,7 @@ class Net(nn.Module):
hidden_dropout=0.1,
attention_dropout=0.1,
dropout_rng_name=DROPOUT_KEY,
layer_type=te.TransformerLayerType.ENCODER,
layer_type=te.flax.TransformerLayerType.ENCODER,
enable_relative_embedding=False,
dtype=jnp.bfloat16)
x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
......@@ -158,12 +143,11 @@ def data_preprocess(dataset, vocab, word_id, max_seq_len):
nltk.download('punkt')
dataset_size = len(dataset['sentence'])
output = np.zeros((dataset_size, max_seq_len), dtype=np.int32)
mask_3d = np.empty((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8)
mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8)
for j, sentence in enumerate(dataset['sentence']):
tokens = nltk.word_tokenize(sentence.decode("utf-8"))
tokens = nltk.word_tokenize(sentence)
tensor = output[j]
mask_1d = np.zeros((1, max_seq_len), dtype=np.uint8)
for i, word in enumerate(tokens):
if i >= max_seq_len:
......@@ -176,26 +160,31 @@ def data_preprocess(dataset, vocab, word_id, max_seq_len):
else:
tensor[i] = vocab[word]
mask_1d[0, i] = 1
seq_len = len(tokens)
if seq_len > max_seq_len:
seq_len = max_seq_len
mask_2d = mask_3d[j]
np.dot(mask_1d.T, mask_1d, out=mask_2d)
np.subtract(1, mask_2d, out=mask_2d)
mask_2d[:seq_len, :seq_len] = 0
dataset['sentence'] = output
dataset['label'] = dataset['label'].astype(np.float32)
dataset['mask'] = mask_3d.reshape((dataset_size, 1, max_seq_len, max_seq_len))
return dataset, vocab, word_id
new_dataset = {
'sentence': output,
'label': dataset['label'].astype(np.float32),
'mask': mask_3d.reshape((dataset_size, 1, max_seq_len, max_seq_len))
}
return new_dataset, vocab, word_id
def get_datasets(max_seq_len):
"""Load GLUE train and test datasets into memory."""
vocab = {}
word_id = 0
dataset = 'glue/cola'
train_ds = tfds.as_numpy(tfds.load(dataset, split='train', batch_size=-1))
train_ds = load_dataset('glue', 'cola', split='train')
train_ds.set_format(type='np')
train_ds, vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len)
test_ds = tfds.as_numpy(tfds.load(dataset, split='validation', batch_size=-1))
test_ds = load_dataset('glue', 'cola', split='validation')
test_ds.set_format(type='np')
test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len)
return train_ds, test_ds, word_id
......@@ -210,11 +199,8 @@ def check_fp8(state, var_collect, inputs, masks, labels):
def train_and_evaluate(args):
"""Execute model training and evaluation loop."""
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
print(args)
if args.use_fp8:
assert gpu_has_fp8(), "GPU needs to support FP8."
train_ds, test_ds, num_embed = get_datasets(args.max_seq_len)
rng = jax.random.PRNGKey(args.seed)
rng, params_rng = jax.random.split(rng)
......@@ -226,7 +212,6 @@ def train_and_evaluate(args):
label_shape = [args.batch_size]
with te.fp8_autocast(enabled=args.use_fp8):
train_ds, test_ds, num_embed = get_datasets(args.max_seq_len)
encoder = Net(num_embed)
inputs = jnp.zeros(input_shape, dtype=jnp.int32)
masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
......@@ -322,6 +307,8 @@ def encoder_parser(args):
class TestEncoder(unittest.TestCase):
"""Encoder unittests"""
gpu_has_fp8, reason = te.fp8.is_fp8_available()
@classmethod
def setUpClass(cls):
"""Run 4 epochs for testing"""
......@@ -332,7 +319,7 @@ class TestEncoder(unittest.TestCase):
actual = train_and_evaluate(self.args)
assert actual[0] < 0.45 and actual[1] > 0.79
@unittest.skipIf(not gpu_has_fp8(), reason='GPU capability is not enough to run FP8')
@unittest.skipIf(not gpu_has_fp8, reason)
def test_te_fp8(self):
"""Test Transformer Engine with FP8"""
self.args.use_fp8 = True
......
datasets
flax
optax
tensorflow-datasets
Pillow
......@@ -3,7 +3,6 @@
# See LICENSE for license information.
""" MNIST training on single GPU"""
import argparse
import os
import unittest
from functools import partial
@@ -11,8 +10,7 @@ import jax
import jax.numpy as jnp
import numpy as np
import optax
import tensorflow_datasets as tfds
from cuda import cudart
from datasets import load_dataset
from flax import linen as nn
from flax.core.frozen_dict import FrozenDict
from flax.training import train_state
@@ -27,19 +25,6 @@ DROPOUT_KEY = 'dropout'
INPUT_KEY = 'input_rng'
def gpu_has_fp8():
"""Check if the GPU has FP8."""
cudaSuccess = cudart.cudaError_t.cudaSuccess
ret, gpu_id = cudart.cudaGetDevice()
assert ret == cudaSuccess
flag = cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMajor
_, major = cudart.cudaDeviceGetAttribute(flag, gpu_id)
flag = cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMinor
_, minor = cudart.cudaDeviceGetAttribute(flag, gpu_id)
sm_arch = major * 10 + minor
return sm_arch >= 89
class Net(nn.Module):
"""CNN model for MNIST."""
use_te: bool = False
@@ -144,13 +129,23 @@ def eval_model(state, test_ds, batch_size, var_collect):
def get_datasets():
"""Load MNIST train and test datasets into memory."""
ds_builder = tfds.builder('mnist')
ds_builder.download_and_prepare()
train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
train_ds['image'] = jnp.float32(train_ds['image']) / 255.
test_ds['image'] = jnp.float32(test_ds['image']) / 255.
return train_ds, test_ds
train_ds = load_dataset('mnist', split='train')
train_ds.set_format(type='np')
batch_size = train_ds['image'].shape[0]
shape = (batch_size, IMAGE_H, IMAGE_W, IMAGE_C)
new_train_ds = {
'image': train_ds['image'].astype(np.float32).reshape(shape) / 255.,
'label': train_ds['label']
}
test_ds = load_dataset('mnist', split='test')
test_ds.set_format(type='np')
batch_size = test_ds['image'].shape[0]
shape = (batch_size, IMAGE_H, IMAGE_W, IMAGE_C)
new_test_ds = {
'image': test_ds['image'].astype(np.float32).reshape(shape) / 255.,
'label': test_ds['label']
}
return new_train_ds, new_test_ds
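The new `get_datasets` replaces the `tensorflow_datasets` loader with HF `datasets` and does the NHWC reshape and [0, 1] scaling by hand. The same normalization step, exercised on synthetic data (shapes assumed from the example's MNIST constants `IMAGE_H = IMAGE_W = 28`, `IMAGE_C = 1`):

```python
import numpy as np

IMAGE_H, IMAGE_W, IMAGE_C = 28, 28, 1  # MNIST constants assumed from the example

def to_nhwc_float(images, labels):
    """Cast uint8 images to float32 in [0, 1] and reshape to NHWC."""
    shape = (images.shape[0], IMAGE_H, IMAGE_W, IMAGE_C)
    return {
        'image': images.astype(np.float32).reshape(shape) / 255.,
        'label': labels,
    }

# Synthetic stand-in for the real MNIST split returned by load_dataset.
fake_images = np.random.randint(0, 256, size=(8, 28, 28), dtype=np.uint8)
fake_labels = np.random.randint(0, 10, size=(8,))
ds = to_nhwc_float(fake_images, fake_labels)
```

Batch-size-minus-one loading is gone: the whole split is materialized as one array, so the channel axis must be added explicitly before it reaches the CNN.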
def check_fp8(state, var_collect, input_shape, label_shape):
@@ -162,11 +157,9 @@ def check_fp8(state, var_collect, input_shape, label_shape):
def train_and_evaluate(args):
"""Execute model training and evaluation loop."""
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
print(args)
if args.use_fp8:
assert gpu_has_fp8(), "GPU needs to support FP8."
args.use_te = True
train_ds, test_ds = get_datasets()
@@ -275,6 +268,8 @@ def mnist_parser(args):
class TestMNIST(unittest.TestCase):
"""MNIST unittests"""
gpu_has_fp8, reason = te.fp8.is_fp8_available()
@classmethod
def setUpClass(cls):
"""Run MNIST without Transformer Engine"""
@@ -299,7 +294,7 @@ class TestMNIST(unittest.TestCase):
actual = train_and_evaluate(self.args)
self.verify(actual)
@unittest.skipIf(not gpu_has_fp8(), reason='GPU capability is not enough to run FP8')
@unittest.skipIf(not gpu_has_fp8, reason)
def test_te_fp8(self):
"""Test Transformer Engine with FP8"""
self.args.use_fp8 = True
......
@@ -9,4 +9,5 @@ pytest -Wignore -v $TE_PATH/tests/jax
pip install -r $TE_PATH/examples/jax/mnist/requirements.txt
pip install -r $TE_PATH/examples/jax/encoder/requirements.txt
pytest -Wignore -v $TE_PATH/examples/jax
pytest -Wignore -v $TE_PATH/examples/jax --ignore=$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py
pytest -Wignore -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py
@@ -13,13 +13,14 @@ from jax import lax
from jax import jit, value_and_grad
from flax import linen as nn
from utils import assert_allclose, is_fp8_supported
from utils import assert_allclose
from transformer_engine.common.recipe import Format
from transformer_engine.jax.cpp_extensions import dgated_gelu, gated_gelu
from transformer_engine.jax.cpp_extensions import dgated_gelu_cast_transpose, gated_gelu_fp8
from transformer_engine.jax.cpp_extensions import dequantize, quantize
from transformer_engine.jax.dot import fp8_dot
from transformer_engine.jax.fp8 import DType, FP8GemmPackage, FP8Helper, _format2dtypes
from transformer_engine.jax.fp8 import is_fp8_available
from transformer_engine.jax.layernorm import layernorm
from transformer_engine.jax.mlp import fp8_ln_mlp
@@ -28,11 +29,12 @@ GEMM_CASES = [(256, 256, 512), (32, 32, 32), (16384, 1024, 2816), (16384, 2816,
FP8_COMPUTE_TYPE = [_format2dtypes(Format.E4M3), _format2dtypes(Format.HYBRID)]
LN_CASES = [(512, 1024)]
DTYPES = [jnp.bfloat16, jnp.float32]
is_fp8_supported, reason = is_fp8_available()
class TestFP8Dot:
@pytest.mark.skipif(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
def test_qdq(self):
FP8_E4M3_MAX = 448
x = jnp.asarray([[-1, 0.1], [2, 3]], jnp.float32)
@@ -74,7 +76,7 @@ class TestFP8Dot:
value_n_grad_func_compiled = jit(value_n_grad_func).lower(a, b).compile()
value_n_grad_func_compiled(a, b)
@pytest.mark.skipif(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('compute_type', FP8_COMPUTE_TYPE)
def test_compile_fp8(self, compute_type):
key = jax.random.PRNGKey(0)
@@ -115,7 +117,7 @@ class TestFP8Dot:
assert_allclose(primitive_out, ref_out)
@pytest.mark.skipif(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('m,n,k', GEMM_CASES)
@pytest.mark.parametrize('compute_type', FP8_COMPUTE_TYPE)
def test_forward_fp8_randint(self, m, n, k, compute_type):
@@ -180,7 +182,7 @@ class TestFP8Dot:
assert_allclose(primitive_a_grad, ref_a_grad)
assert_allclose(primitive_b_grad, ref_b_grad, atol=1e-5)
@pytest.mark.skipif(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('m,n,k', GEMM_CASES)
@pytest.mark.parametrize('compute_type', FP8_COMPUTE_TYPE)
def test_grad_fp8_randint(self, m, n, k, compute_type):
@@ -252,7 +254,7 @@ class TestFP8Dot:
assert_allclose(primitive_a_grad, ref_a_grad)
assert_allclose(primitive_b_grad, ref_b_grad)
@pytest.mark.skipif(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('m,n,k', [(256, 256, 512), (16384, 1024, 2816), (16384, 2816, 1024),
(16384, 1024, 1024)])
def test_grad_fp8_mlp_randint(self, m, n, k):
@@ -464,7 +466,7 @@ class TestGatedGeLuFP8(TestGatedGeLu):
return func(inputs, no_use, no_use)
@pytest.mark.skipif(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('shape', [(32, 64), (64, 256)])
def test_gated_gelu(self, random_inputs):
self.amax = jnp.zeros(1, jnp.float32)
......
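The `test_qdq` case above round-trips values through `quantize` and `dequantize` against the E4M3 range bound of 448. The scaling idea behind it can be sketched in NumPy (a simplified emulation, not the TE kernels: real E4M3 also rounds the mantissa, only the scale-and-clip step is shown here):

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # largest representable E4M3 magnitude

def quantize(x, scale):
    """Scale into the FP8 representable range and clip (mantissa rounding omitted)."""
    return np.clip(x * scale, -FP8_E4M3_MAX, FP8_E4M3_MAX)

def dequantize(q, scale):
    """Undo the scaling applied by quantize."""
    return q / scale

x = np.array([[-1.0, 0.1], [2.0, 3.0]], np.float32)  # same sample values as test_qdq
amax = np.abs(x).max()
scale = FP8_E4M3_MAX / amax     # delayed-scaling style scale factor
roundtrip = dequantize(quantize(x, scale), scale)
```

With no values clipped and rounding omitted, the round-trip recovers the input up to float32 arithmetic, which is why the real test can use a tight `assert_allclose`.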
@@ -10,19 +10,21 @@ import jax.numpy as jnp
import numpy as np
from jax.experimental import maps
from utils import assert_allclose, is_fp8_supported
from utils import assert_allclose
from transformer_engine.common.recipe import DelayedScaling
from transformer_engine.common.recipe import Format as FP8Format
from transformer_engine.jax import fp8_autocast, get_delayed_scaling
from transformer_engine.jax.fp8 import FP8Helper
from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available
from transformer_engine.jax.sharding import infer_major_sharding_type
from transformer_engine.jax.sharding import MajorShardingType
from transformer_engine.jax.sharding import ShardingResource
is_fp8_supported, reason = is_fp8_available()
class TestFP8Helper(unittest.TestCase):
@unittest.skipIf(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
@unittest.skipIf(not is_fp8_supported, reason=reason)
def test_initialize(self):
margin = 5.0
fp8_format = FP8Format.E4M3
@@ -52,7 +54,7 @@ class TestFP8Helper(unittest.TestCase):
FP8Helper.finalize()
@unittest.skipIf(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
@unittest.skipIf(not is_fp8_supported, reason=reason)
def test_update_fp8_metas(self):
FP8Helper.initialize(margin=3.0, amax_history_len=3)
@@ -113,7 +115,7 @@ class TestFP8Helper(unittest.TestCase):
FP8Helper.finalize()
@unittest.skipIf(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
@unittest.skipIf(not is_fp8_supported, reason=reason)
def test_generate_fp8_max_array(self):
num_of_meta = FP8Helper.NUM_META_PER_GEMM * 2
@@ -131,7 +133,7 @@ class TestFP8Helper(unittest.TestCase):
assert_allclose(get_ref(fp8_format), FP8Helper.generate_fp8_max_array(num_of_meta))
FP8Helper.finalize()
@unittest.skipIf(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
@unittest.skipIf(not is_fp8_supported, reason=reason)
def test_update_collections(self):
original_val = 0.0
updated_val = 10.0
@@ -163,7 +165,7 @@ class TestFP8Functions(unittest.TestCase):
self.assertTrue(ref.amax_history_len == test.amax_history_len)
self.assertTrue(ref.amax_compute_algo == test.amax_compute_algo)
@unittest.skipIf(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
@unittest.skipIf(not is_fp8_supported, reason=reason)
def test_fp8_autocast(self):
FP8Helper.finalize() # Ensure the testing not affect by previous tests.
self._check_defult_state()
@@ -188,7 +190,7 @@ class TestFP8Functions(unittest.TestCase):
self._check_defult_state()
@unittest.skipIf(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
@unittest.skipIf(not is_fp8_supported, reason=reason)
def test_fp8_autocast_with_sharding_resource(self):
FP8Helper.finalize() # Ensure the testing not affect by previous tests.
self._check_defult_state()
......
@@ -11,11 +11,13 @@ import pytest
from transformer_engine.common.recipe import Format
from transformer_engine.jax.flax import TransformerLayer, TransformerLayerType
from transformer_engine.jax.fp8 import FP8Helper
from utils import assert_allclose, is_fp8_supported
from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available
from utils import assert_allclose
from utils import DecoderLayer as RefDecoderLayer
from utils import EncoderLayer as RefEncoderLayer
is_fp8_supported, reason = is_fp8_available()
def loss_fn(diff_xs, no_diff_xs, params, others, model, rngs):
output = model.apply({"params": params, **others}, *diff_xs, *no_diff_xs, rngs=rngs)
@@ -289,7 +291,7 @@ class TestEncoderLayer:
FP8Helper.finalize() # Ensure FP8 disabled.
self.forward_runner(data_shape, dtype, attrs, rtol=1e-05, atol=2e-04)
@pytest.mark.skipif(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('data_shape', DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE)
@pytest.mark.parametrize('fp8_format', FP8_FORMATS)
@@ -306,7 +308,7 @@ class TestEncoderLayer:
FP8Helper.finalize() # Ensure FP8 disabled.
self.forward_backward_runner(data_shape, dtype, attrs, rtol=1e-05, atol=2e-04)
@pytest.mark.skipif(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('data_shape', DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE)
@pytest.mark.parametrize('fp8_format', FP8_FORMATS)
@@ -516,7 +518,7 @@ class TestDecoderLayer:
FP8Helper.finalize() # Ensure FP8 disabled.
self.forward_runner(data_shape, dtype, attrs, rtol=1e-05, atol=2e-04)
@pytest.mark.skipif(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('data_shape', DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE)
@pytest.mark.parametrize('fp8_format', FP8_FORMATS)
@@ -533,7 +535,7 @@ class TestDecoderLayer:
FP8Helper.finalize() # Ensure FP8 disabled.
self.forward_backward_runner(data_shape, dtype, attrs, rtol=1e-05, atol=2e-04)
@pytest.mark.skipif(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('data_shape', DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE)
@pytest.mark.parametrize('fp8_format', FP8_FORMATS)
......
@@ -7,6 +7,7 @@ import numpy as np
import pytest
from jax.experimental import maps
from utils import is_devices_enough
from transformer_engine.jax.flax import extend_logical_axis_rules
from transformer_engine.jax.sharding import get_dot_sharding_meta
from transformer_engine.jax.sharding import get_elementwise_sharding_meta
@@ -15,7 +16,6 @@ from transformer_engine.jax.sharding import global_shard_guard
from transformer_engine.jax.sharding import infer_major_sharding_type
from transformer_engine.jax.sharding import is_dp_enabled, is_tp_enabled
from transformer_engine.jax.sharding import ShardingMeta, ShardingResource, ShardingType
from utils import is_devices_enough
def _get_sharding_resource(mesh_names, sharding_type):
......
@@ -9,7 +9,6 @@ from typing import Any, Callable, Tuple, Sequence, Union, Iterable, Optional
import jax
import jax.numpy as jnp
import numpy as np
from cuda import cudart
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
from jax import lax, vmap
@@ -25,20 +24,6 @@ PrecisionLike = Union[None, str, lax.Precision, Tuple[str, str], Tuple[lax.Preci
Initializer = Callable[[PRNGKey, Shape, DType], Array]
def is_fp8_supported():
"""
Thus JAX doesn't have API to query capability
Use cuda-python for get the compute capability
"""
cudaSuccess = cudart.cudaError_t.cudaSuccess
ret, gpu_id = cudart.cudaGetDevice()
assert ret == cudaSuccess
flag = cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMajor
ret, sm_major = cudart.cudaDeviceGetAttribute(flag, gpu_id)
assert ret == cudaSuccess
return sm_major >= 9
def is_devices_enough(required):
return len(jax.devices()) >= required
......
@@ -6,6 +6,7 @@ pybind11_add_module(
transformer_engine_jax
${CMAKE_CURRENT_SOURCE_DIR}/csrc/extensions.cpp
${CMAKE_CURRENT_SOURCE_DIR}/csrc/modules.cpp
${CMAKE_CURRENT_SOURCE_DIR}/csrc/utils.cpp
)
target_link_libraries(transformer_engine_jax PRIVATE CUDA::cudart CUDA::cublas CUDA::cublasLt transformer_engine)
@@ -6,6 +6,8 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <cublasLt.h>
#include "common/include/transformer_engine/fused_attn.h"
#include "common/include/transformer_engine/transformer_engine.h"
#include "jax/csrc/modules.h"
@@ -58,6 +60,9 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
m.def("pack_gemm_descriptor", &PackCustomCallGemmDescriptor);
m.def("pack_norm_descriptor", &PackCustomCallNormDescriptor);
m.def("pack_softmax_descriptor", &PackCustomCallSoftmaxDescriptor);
m.def("get_cublasLt_version", &cublasLtGetVersion);
m.def("get_cuda_version", &GetCudaRuntimeVersion);
m.def("get_device_compute_capability", &GetDeviceComputeCapability);
m.def("pack_fused_attn_descriptor", &PackCustomCallFusedAttnDescriptor);
m.def("is_fused_attn_kernel_available", &IsFusedAttnKernelAvailable);
......
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda_runtime_api.h>
#include <cassert>
#include "utils.h"
namespace transformer_engine {
namespace jax {
int GetCudaRuntimeVersion() {
int ver = 0;
NVTE_CHECK_CUDA(cudaRuntimeGetVersion(&ver));
return ver;
}
int GetDeviceComputeCapability(int gpu_id) {
int max_num_gpu = 0;
NVTE_CHECK_CUDA(cudaGetDeviceCount(&max_num_gpu));
assert(gpu_id < max_num_gpu);
int major = 0;
NVTE_CHECK_CUDA(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, gpu_id));
int minor = 0;
NVTE_CHECK_CUDA(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, gpu_id));
int gpu_arch = major * 10 + minor;
return gpu_arch;
}
} // namespace jax
} // namespace transformer_engine
@@ -18,6 +18,9 @@
namespace transformer_engine {
namespace jax {
int GetCudaRuntimeVersion();
int GetDeviceComputeCapability(int gpu_id);
class cublasLtMetaManager {
public:
static cublasLtMetaManager &Instance() {
......
@@ -13,13 +13,50 @@ import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from transformer_engine_jax import DType
from transformer_engine_jax import get_cublasLt_version
from transformer_engine_jax import get_cuda_version, get_device_compute_capability
from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.jax.sharding import global_shard_guard
from transformer_engine.jax.sharding import ShardingResource
_is_fp8_available = None
_reason_for_no_fp8 = ""
Collection = Union[Dict, FrozenDict]
def _check_fp8_support(gpu_id) -> Tuple[bool, str]:
"""Return if fp8 support is available"""
gpu_arch = get_device_compute_capability(gpu_id)
if gpu_arch >= 90: # hopper and above
return True, ""
if gpu_arch < 89: # pre-ada
return False, "Device compute capability 8.9 or higher required for FP8 execution."
if get_cublasLt_version() < 120103:
return False, "CublasLt version 12.1.3.x or higher required for FP8 execution on Ada."
if get_cuda_version() < 12010:
return False, "Cuda version 12.1 or higher required for FP8 execution on Ada."
return True, ""
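`_check_fp8_support` gates on three tiers: Hopper (sm90+) always passes, pre-Ada (below sm89) always fails, and Ada (sm89) additionally needs cublasLt 12.1.3+ and CUDA 12.1+. The decision table can be exercised in isolation by passing the versions in explicitly; the function below is a pure-Python mirror with the pybind11 `get_device_compute_capability` / `get_cublasLt_version` / `get_cuda_version` bindings replaced by arguments:

```python
def check_fp8_support(gpu_arch, cublaslt_version, cuda_version):
    """Mirror of _check_fp8_support with the hardware/library queries injected."""
    if gpu_arch >= 90:          # Hopper and above: FP8 unconditionally
        return True, ""
    if gpu_arch < 89:           # pre-Ada: no FP8 hardware at all
        return False, "Device compute capability 8.9 or higher required for FP8 execution."
    if cublaslt_version < 120103:   # Ada (sm89) also needs cublasLt 12.1.3+
        return False, "CublasLt version 12.1.3.x or higher required for FP8 execution on Ada."
    if cuda_version < 12010:        # ...and the CUDA 12.1 runtime
        return False, "Cuda version 12.1 or higher required for FP8 execution on Ada."
    return True, ""
```

Returning the reason alongside the flag is what lets the test skip decorators and the `fp8_autocast` assertion report exactly which requirement failed.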
def is_fp8_available(gpu_id=None) -> Tuple[bool, str]:
"""Return if fp8 support is available"""
if gpu_id is not None:
return _check_fp8_support(gpu_id)
global _is_fp8_available, _reason_for_no_fp8
if _is_fp8_available is None:
_is_fp8_available = True
devices = jax.local_devices()
for gpu in devices:
ret, msg = _check_fp8_support(gpu.id)
if ret is False:
_is_fp8_available = ret
_reason_for_no_fp8 = msg
break
return _is_fp8_available, _reason_for_no_fp8
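When called without a `gpu_id`, `is_fp8_available` scans every local device once and memoizes the verdict in module-level globals, so the many `skipif` decorators that call it pay the device query only once. The memoization shape, with the JAX device scan and the per-device check stubbed out:

```python
_is_fp8_available = None   # module-level cache, as in transformer_engine.jax.fp8
_reason_for_no_fp8 = ""

def _check_fp8_support(gpu_id):
    """Stub: pretend device 1 is pre-Ada and every other device is Hopper."""
    if gpu_id == 1:
        return False, "Device compute capability 8.9 or higher required for FP8 execution."
    return True, ""

def is_fp8_available(device_ids=(0, 1, 2)):
    """All local devices must support FP8; the first failing device sets the reason."""
    global _is_fp8_available, _reason_for_no_fp8
    if _is_fp8_available is None:       # only scan on the first call
        _is_fp8_available = True
        for gpu_id in device_ids:
            ok, msg = _check_fp8_support(gpu_id)
            if not ok:
                _is_fp8_available = False
                _reason_for_no_fp8 = msg
                break
    return _is_fp8_available, _reason_for_no_fp8
```

Note the all-devices semantics: one unsupported GPU in a heterogeneous node disables FP8 for the whole process, which is the conservative choice for multi-GPU examples.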
def _format2dtypes(format_: Format):
if format_ == Format.E4M3:
return DType.kFloat8E4M3, DType.kFloat8E4M3
@@ -332,6 +369,9 @@ def fp8_autocast(enabled: bool = False,
try:
with global_shard_guard(sharding_resource):
if enabled:
fp8_available, reason_for_no_fp8 = is_fp8_available()
assert fp8_available, reason_for_no_fp8
amax_compute_algo = AmaxComputeAlgo.MOST_RECENT
if fp8_recipe.amax_compute_algo == 'max':
amax_compute_algo = AmaxComputeAlgo.MAX
......
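With the new assertion, `fp8_autocast(enabled=True)` fails fast on entry with the library's reason string instead of erroring deep inside a custom call. A minimal sketch of the guard, with a stub standing in for `transformer_engine.jax.fp8.is_fp8_available` and the rest of the context manager elided:

```python
from contextlib import contextmanager

def is_fp8_available():
    """Stub for transformer_engine.jax.fp8.is_fp8_available on a pre-Ada GPU."""
    return False, "Device compute capability 8.9 or higher required for FP8 execution."

@contextmanager
def fp8_autocast(enabled=False):
    """Guard shape only: recipe handling and sharding are omitted."""
    if enabled:
        fp8_available, reason_for_no_fp8 = is_fp8_available()
        assert fp8_available, reason_for_no_fp8   # fail fast with the precise reason
    yield

# enabled=False never touches the check, matching the real autocast.
with fp8_autocast(enabled=False):
    pass
```

On unsupported hardware the `AssertionError` surfaces at the `with` statement, carrying the same message the skip decorators show in the test suites.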