# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. import pytest import jax import jax.numpy as jnp from utils import assert_allclose from transformer_engine.jax.flax.module import _apply_low_rank_adaptation from transformer_engine.jax.flax.module import _normalize_axes from transformer_engine.jax.flax.transformer import LoRAScope from transformer_engine.jax.flax.transformer import _canonicalize_lora_scope class TestLoRA: def reference(x, la, lb, pattern, scale): out = jnp.einsum(pattern, x, la, lb) return out * scale @pytest.mark.parametrize("shape", [(32, 1024), (32, 128, 1024)]) @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) @pytest.mark.parametrize( "axis_features_pattern", [((-1,), (1024,), "...h,hr,rk->...k"), ((-1,), (3, 1024), "...h,hkr,krz->...kz")], ) @pytest.mark.parametrize("rank", [32, 16]) @pytest.mark.parametrize("alpha", [None, 4, 8]) def test_lora(self, shape, dtype, axis_features_pattern, rank, alpha): axis, features, pattern = axis_features_pattern axis = _normalize_axes(axis, len(shape)) shape_in_axis = tuple(shape[ax] for ax in axis) key = jax.random.key(1124) key, x_key = jax.random.split(key) x = jax.random.normal(x_key, shape, dtype) key, la_key = jax.random.split(key) la_shape = (*shape_in_axis, *features[:-1], rank) la = jax.random.normal(la_key, la_shape, dtype) key, lb_key = jax.random.split(key) lb_shape = (*features[:-1], rank, features[-1]) lb = jax.random.normal(lb_key, lb_shape, dtype) out_target = _apply_low_rank_adaptation(x, axis, features, la, lb, alpha) scale_ref = alpha / rank if alpha is not None else 1.0 out_ref = TestLoRA.reference(x, la, lb, pattern, scale_ref) assert_allclose(out_target, out_ref, dtype=dtype) @pytest.mark.parametrize( "scope_ref_assert", [ ("none", LoRAScope(False, False, False), False), ("all", LoRAScope(True, True, True), False), ("qkv_proj", LoRAScope(True, False, False), False), ("output_proj", LoRAScope(False, True, False), False), ("mlp", LoRAScope(False, False, True), False), ("exclude_qkv_proj", LoRAScope(False, True, True), False), ("exclude_output_proj", LoRAScope(True, False, True), False), ("exclude_mlp", LoRAScope(True, True, False), False), ("messing_up", LoRAScope(), True), ], ) def test_lora_scope_generator(self, scope_ref_assert): scope, reference, need_assert = scope_ref_assert try: lora_scope = _canonicalize_lora_scope(scope) assert lora_scope == reference except AssertionError as ae: assert need_assert, f"{ae.args}"