Commit ff373a11 authored by xyupeng's avatar xyupeng Committed by Frank Lee
Browse files

[NFC] polish tests/test_layers/test_sequence/checks_seq/check_layer_seq.py code style (#1723)

parent 7e62af28
......@@ -12,15 +12,10 @@ def check_selfattention():
BATCH = 4
HIDDEN_SIZE = 16
layer = TransformerSelfAttentionRing(
16,
8,
8,
0.1
)
layer = TransformerSelfAttentionRing(16, 8, 8, 0.1)
layer = layer.to(get_current_device())
hidden_states = torch.rand(SUB_SEQ_LENGTH, BATCH, HIDDEN_SIZE).to(get_current_device())
attention_mask = torch.randint(low=0, high=2, size=(BATCH, 1, 1, 1, SUB_SEQ_LENGTH * WORLD_SIZE)).to(
get_current_device())
attention_mask = torch.randint(low=0, high=2,
size=(BATCH, 1, 1, 1, SUB_SEQ_LENGTH * WORLD_SIZE)).to(get_current_device())
out = layer(hidden_states, attention_mask)
Markdown is supported
Attach a file by drag & drop or click to upload. 0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment