import torch

from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn import TransformerSelfAttentionRing
from colossalai.utils import get_current_device


def check_selfattention():
    """Run a forward pass of TransformerSelfAttentionRing on random inputs.

    Assumes the distributed environment has already been initialized with a
    sequence-parallel process group, so gpc can report its world size.
    """
    WORLD_SIZE = gpc.get_world_size(ParallelMode.SEQUENCE)
    SUB_SEQ_LENGTH = 8
    BATCH = 4
    HIDDEN_SIZE = 16

    # Build the ring self-attention layer under test; the positional
    # arguments follow the TransformerSelfAttentionRing signature, with the
    # hidden size first.
    layer = TransformerSelfAttentionRing(
        HIDDEN_SIZE,
        8,
        8,
        0.1
    )
    layer = layer.to(get_current_device())

    # Each rank holds only its local sub-sequence, shaped
    # (sub_seq_len, batch, hidden_size).
    hidden_states = torch.rand(SUB_SEQ_LENGTH, BATCH, HIDDEN_SIZE).to(get_current_device())
    # The binary attention mask spans the full sequence across all ranks,
    # i.e. SUB_SEQ_LENGTH * WORLD_SIZE positions.
    attention_mask = torch.randint(
        low=0, high=2,
        size=(BATCH, 1, 1, 1, SUB_SEQ_LENGTH * WORLD_SIZE)
    ).to(get_current_device())
    out = layer(hidden_states, attention_mask)
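
    # Minimal sanity check, a sketch rather than part of the original test:
    # it assumes the layer returns either the attention output alone or a
    # Megatron-style (output, bias) tuple, and that self-attention preserves
    # the (sub_seq_len, batch, hidden_size) shape of its input.
    output = out[0] if isinstance(out, tuple) else out
    assert output.shape == hidden_states.shape


# Illustrative entry point (assumption: in practice this check is driven by a
# distributed test launcher after colossalai.launch has set up the
# sequence-parallel context; running it standalone would fail inside gpc).
if __name__ == '__main__':
    check_selfattention()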