# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import pytest

import jax
import jax.numpy as jnp

from utils import assert_allclose
from transformer_engine.jax.flax.module import _apply_low_rank_adaptation
from transformer_engine.jax.flax.module import _normalize_axes
from transformer_engine.jax.flax.transformer import LoRAScope
from transformer_engine.jax.flax.transformer import _canonicalize_lora_scope


class TestLoRA:
    """Tests for the LoRA helpers imported from ``transformer_engine.jax.flax``."""

    @staticmethod
    def reference(x, la, lb, pattern, scale):
        """Compute the expected low-rank adaptation output with a single einsum.

        Args:
            x: Input array.
            la: First LoRA weight (projects down to ``rank``).
            lb: Second LoRA weight (projects ``rank`` back to the output features).
            pattern: ``jnp.einsum`` subscripts describing how ``x``, ``la`` and
                ``lb`` are contracted.
            scale: Scalar multiplier applied to the einsum result.

        Returns:
            ``jnp.einsum(pattern, x, la, lb) * scale``.
        """
        out = jnp.einsum(pattern, x, la, lb)
        return out * scale

    @pytest.mark.parametrize("shape", [(32, 1024), (32, 128, 1024)])
    @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16])
    @pytest.mark.parametrize(
        "axis_features_pattern",
        [((-1,), (1024,), "...h,hr,rk->...k"), ((-1,), (3, 1024), "...h,hkr,krz->...kz")],
    )
    @pytest.mark.parametrize("rank", [32, 16])
    @pytest.mark.parametrize("alpha", [None, 4, 8])
    def test_lora(self, shape, dtype, axis_features_pattern, rank, alpha):
        """Check ``_apply_low_rank_adaptation`` against the einsum reference.

        ``axis_features_pattern`` bundles the input axes passed as ``axis``, the
        output ``features`` shape, and the einsum subscripts used by
        ``reference``. The expected scale is ``alpha / rank`` when ``alpha`` is
        given and ``1.0`` otherwise.
        """
        axis, features, pattern = axis_features_pattern
        axis = _normalize_axes(axis, len(shape))
        shape_in_axis = tuple(shape[ax] for ax in axis)

        key = jax.random.key(1124)
        key, x_key = jax.random.split(key)
        x = jax.random.normal(x_key, shape, dtype)

        # la: (selected input dims..., leading feature dims..., rank)
        key, la_key = jax.random.split(key)
        la_shape = (*shape_in_axis, *features[:-1], rank)
        la = jax.random.normal(la_key, la_shape, dtype)

        # lb: (leading feature dims..., rank, last feature dim)
        key, lb_key = jax.random.split(key)
        lb_shape = (*features[:-1], rank, features[-1])
        lb = jax.random.normal(lb_key, lb_shape, dtype)

        out_target = _apply_low_rank_adaptation(x, axis, features, la, lb, alpha)
        scale_ref = alpha / rank if alpha is not None else 1.0
        out_ref = TestLoRA.reference(x, la, lb, pattern, scale_ref)

        assert_allclose(out_target, out_ref, dtype=dtype)

    @pytest.mark.parametrize(
        "scope_ref_assert",
        [
            ("none", LoRAScope(False, False, False), False),
            ("all", LoRAScope(True, True, True), False),
            ("qkv_proj", LoRAScope(True, False, False), False),
            ("output_proj", LoRAScope(False, True, False), False),
            ("mlp", LoRAScope(False, False, True), False),
            ("exclude_qkv_proj", LoRAScope(False, True, True), False),
            ("exclude_output_proj", LoRAScope(True, False, True), False),
            ("exclude_mlp", LoRAScope(True, True, False), False),
            ("messing_up", LoRAScope(), True),
        ],
    )
    def test_lora_scope_generator(self, scope_ref_assert):
        """Check ``_canonicalize_lora_scope`` maps scope names to ``LoRAScope`` flags.

        Each case is ``(scope, reference, need_assert)``. Valid scope names
        must produce exactly ``reference``. Invalid names (``need_assert`` is
        True) must raise ``AssertionError``.
        """
        scope, reference, need_assert = scope_ref_assert
        if need_assert:
            # Require the raise explicitly. The placeholder ``reference`` for an
            # invalid scope is not meaningful, so a non-raising implementation
            # must not be able to pass by returning a matching value.
            with pytest.raises(AssertionError):
                _canonicalize_lora_scope(scope)
        else:
            assert _canonicalize_lora_scope(scope) == reference