test_linear.py 1.11 KB
Newer Older
yongshk's avatar
yongshk committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import candle
from candle import Tensor
from candle.nn import Linear


def test_linear_layer_can_be_constructed():
    """Smoke test: calling ``Linear(10, 10)`` yields an object that is not None."""
    layer = Linear(10, 10)
    assert layer is not None


def test_linear_layer_can_forward_a_singular_input():
    """Forward a 2-D ``(8, 384)`` tensor through ``Linear(384, 1536)``.

    The output is expected to keep the leading dimension and have 1536 as
    its last dimension.
    """
    rows, n_in, n_out = 8, 384, 1536
    layer = Linear(n_in, n_out)
    # The layer is built before the random input is drawn, as in the
    # original ordering of the two calls.
    result = layer.forward(candle.randn((rows, n_in)))
    assert result.shape == (rows, n_out)


def test_linear_layer_can_forward_a_batched_input():
    """Forward a 3-D ``(16, 8, 384)`` tensor through ``Linear(384, 1536)``.

    Both leading dimensions are expected to be preserved, with only the
    last dimension changing from 384 to 1536.
    """
    leading_dims = (16, 8)
    layer = Linear(384, 1536)
    result = layer.forward(candle.randn((*leading_dims, 384)))
    assert result.shape == (*leading_dims, 1536)


def test_quantized_linear_layer_can_forward_a_singular_input():
    """Forward a 2-D input through a layer whose weight was quantized.

    The layer's ``weight`` is replaced by the result of
    ``weight.quantize("q4_0")`` before a ``(8, 384)`` tensor is forwarded;
    the output shape is expected to be ``(8, 1536)``.
    """
    rows, n_in, n_out = 8, 384, 1536
    layer = Linear(n_in, n_out)
    quantized_weight = layer.weight.quantize("q4_0")
    layer.weight = quantized_weight
    result = layer.forward(candle.randn((rows, n_in)))
    assert result.shape == (rows, n_out)


def test_quantized_linear_layer_can_forward_a_batched_input():
    """Forward a 3-D input through a layer whose weight was quantized.

    The layer's ``weight`` is replaced by the result of
    ``weight.quantize("q4_0")`` before a ``(16, 8, 384)`` tensor is
    forwarded; the output shape is expected to be ``(16, 8, 1536)``.
    """
    leading_dims = (16, 8)
    layer = Linear(384, 1536)
    quantized_weight = layer.weight.quantize("q4_0")
    layer.weight = quantized_weight
    result = layer.forward(candle.randn((*leading_dims, 384)))
    assert result.shape == (*leading_dims, 1536)