test_functions.py 3.02 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import pytest

import jax
import jax.numpy as jnp

from utils import assert_allclose
from transformer_engine.jax.flax.module import _apply_low_rank_adaptation
from transformer_engine.jax.flax.module import _normalize_axes
from transformer_engine.jax.flax.transformer import LoRAScope
from transformer_engine.jax.flax.transformer import _canonicalize_lora_scope


class TestLoRA:
    """Tests for the JAX LoRA helpers ``_apply_low_rank_adaptation`` and
    ``_canonicalize_lora_scope``."""

    @staticmethod
    def reference(x, la, lb, pattern, scale):
        """Reference low-rank adaptation computed with a single einsum.

        Args:
            x: input array.
            la: first LoRA weight; its trailing dimension is the rank.
            lb: second LoRA weight; it maps from the rank to the output features.
            pattern: einsum subscripts contracting ``x`` with ``la`` and then ``lb``.
            scale: scalar multiplier applied to the einsum result.

        Returns:
            ``jnp.einsum(pattern, x, la, lb) * scale``.
        """
        out = jnp.einsum(pattern, x, la, lb)
        return out * scale

    @pytest.mark.parametrize('shape', [(32, 1024), (32, 128, 1024)])
    @pytest.mark.parametrize('dtype', [jnp.float32, jnp.bfloat16])
    @pytest.mark.parametrize('axis_features_pattern', [((-1,), (1024,), '...h,hr,rk->...k'),
                                                       ((-1,), (3, 1024), '...h,hkr,krz->...kz')])
    @pytest.mark.parametrize('rank', [32, 16])
    @pytest.mark.parametrize('alpha', [None, 4, 8])
    def test_lora(self, shape, dtype, axis_features_pattern, rank, alpha):
        """Compare ``_apply_low_rank_adaptation`` with the einsum reference."""
        axis, features, pattern = axis_features_pattern
        axis = _normalize_axes(axis, len(shape))
        # Sizes of the contracted input dimensions; they lead the shape of `la`.
        shape_in_axis = tuple(shape[ax] for ax in axis)

        key = jax.random.key(1124)
        key, x_key = jax.random.split(key)
        x = jax.random.normal(x_key, shape, dtype)

        # la: (*contracted input dims, *leading feature dims, rank)
        key, la_key = jax.random.split(key)
        la_shape = (*shape_in_axis, *features[:-1], rank)
        la = jax.random.normal(la_key, la_shape, dtype)

        # lb: (*leading feature dims, rank, last feature dim)
        key, lb_key = jax.random.split(key)
        lb_shape = (*features[:-1], rank, features[-1])
        lb = jax.random.normal(lb_key, lb_shape, dtype)

        out_target = _apply_low_rank_adaptation(x, axis, features, la, lb, alpha)
        # Expected scaling: alpha / rank when alpha is given, otherwise none.
        scale_ref = alpha / rank if alpha is not None else 1.0
        out_ref = self.reference(x, la, lb, pattern, scale_ref)

        assert_allclose(out_target, out_ref, dtype=dtype)

    @pytest.mark.parametrize('scope_ref_assert',
                             [('none', LoRAScope(False, False, False), False),
                              ('all', LoRAScope(True, True, True), False),
                              ('qkv_proj', LoRAScope(True, False, False), False),
                              ('output_proj', LoRAScope(False, True, False), False),
                              ('mlp', LoRAScope(False, False, True), False),
                              ('exclude_qkv_proj', LoRAScope(False, True, True), False),
                              ('exclude_output_proj', LoRAScope(True, False, True), False),
                              ('exclude_mlp', LoRAScope(True, True, False), False),
                              ('messing_up', LoRAScope(), True)])
    def test_lora_scope_generator(self, scope_ref_assert):
        """Check scope-name canonicalization.

        Each valid name must map to the expected ``LoRAScope``. A case marked
        ``need_assert`` must be rejected with an ``AssertionError``.
        """
        scope, reference, need_assert = scope_ref_assert
        if need_assert:
            # pytest.raises fails the test when nothing is raised, so an
            # invalid scope that is silently accepted is no longer a pass.
            with pytest.raises(AssertionError):
                _canonicalize_lora_scope(scope)
            return
        assert _canonicalize_lora_scope(scope) == reference