Unverified commit 19e1a5cf, authored by Hongxin Liu and committed by GitHub

[shardformer] update colo attention to support custom mask (#5510)

* [feature] refactor colo attention (#5462)

* [extension] update api

* [feature] add colo attention

* [feature] update sdpa

* [feature] update npu attention

* [feature] update flash-attn

* [test] add flash attn test

* [test] update flash attn test

* [shardformer] update modeling to fit colo attention (#5465)

* [misc] refactor folder structure

* [shardformer] update llama flash-attn

* [shardformer] fix llama policy

* [devops] update tensornvme install

* [test] update llama test

* [shardformer] update colo attn kernel dispatch

* [shardformer] update blip2

* [shardformer] update chatglm

* [shardformer] update gpt2

* [shardformer] update gptj

* [shardformer] update opt

* [shardformer] update vit

* [shardformer] update colo attention mask prep

* [shardformer] update whisper

* [test] fix shardformer tests (#5514)

* [test] fix shardformer tests

* [test] fix shardformer tests
parent 9a3321e9
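For context on the commit title: the sketch below illustrates how a custom boolean attention mask is typically threaded through an attention call. It uses PyTorch's public torch.nn.functional.scaled_dot_product_attention, not ColossalAI's ColoAttention API; every name and shape here is an illustrative assumption, not the interface added by this commit.

# Illustrative only: a custom mask (causal + key padding) passed to PyTorch's
# built-in SDPA. This is not ColoAttention's API.
import torch
import torch.nn.functional as F

B, H, S, D = 2, 4, 8, 16
q, k, v = (torch.randn(B, H, S, D) for _ in range(3))

# Custom mask: True = attend, False = mask out. A causal mask combined with
# key padding that drops the last two positions of every sequence.
causal = torch.tril(torch.ones(S, S, dtype=torch.bool))
key_padding = torch.ones(B, S, dtype=torch.bool)
key_padding[:, -2:] = False
custom_mask = causal[None, None] & key_padding[:, None, None, :]  # [B, 1, S, S]

out = F.scaled_dot_product_attention(q, k, v, attn_mask=custom_mask)
print(out.shape)  # torch.Size([2, 4, 8, 16])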
......@@ -25,7 +25,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
)
org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
org_model,
sharded_model,
sharded_optimizer,
data_gen_fn,
output_transform_fn,
criterion,
booster,
)
stage_manager = booster.plugin.stage_manager
......@@ -46,11 +52,25 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
else:
atol, rtol = 5e-3, 5e-3
col_layer_grads = get_grad_tensors_for_check(
gptj, sharded_gptj, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False
gptj,
sharded_gptj,
col_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=0,
verbose=False,
)
row_layer_grads = get_grad_tensors_for_check(
gptj, sharded_gptj, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
gptj,
sharded_gptj,
row_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=1,
verbose=False,
)
grads_to_check.update(col_layer_grads)
grads_to_check.update(row_layer_grads)
......@@ -77,7 +97,16 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
atol, rtol = 5e-3, 1e-3
else:
atol, rtol = 5e-3, 5e-3
check_weight(gptj, sharded_gptj, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False)
check_weight(
gptj,
sharded_gptj,
col_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=0,
verbose=False,
)
# check grads
check_all_grad_tensors(grads_to_check)
......@@ -110,14 +139,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
{
"tp_size": 4,
"pp_size": 1,
"enable_all_optimization": True,
"enable_all_optimization": False,
"use_lazy_init": False,
"precision": "fp32",
},
{
"tp_size": 2,
"pp_size": 1,
"enable_all_optimization": True,
"enable_all_optimization": False,
"use_lazy_init": False,
"precision": "fp32",
},
......@@ -125,7 +154,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"tp_size": 2,
"pp_size": 2,
"num_microbatches": 4,
"enable_all_optimization": True,
"enable_all_optimization": False,
#'use_lazy_init': True,
"precision": "fp32",
},
......@@ -154,7 +183,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
def run_gptj_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_gptj")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
for name, (
model_fn,
data_gen_fn,
output_transform_fn,
loss_fn,
_,
) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
......@@ -189,7 +224,13 @@ def run_gptj_test(test_config):
def run_gptj_3d_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_gptj")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
for name, (
model_fn,
data_gen_fn,
output_transform_fn,
loss_fn,
_,
) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
......@@ -198,15 +239,30 @@ def run_gptj_3d_test(test_config):
def check_gptj(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
colossalai.launch(
config={},
rank=rank,
world_size=world_size,
host="localhost",
port=port,
backend="nccl",
)
run_gptj_test()
def check_gptj_3d(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
colossalai.launch(
config={},
rank=rank,
world_size=world_size,
host="localhost",
port=port,
backend="nccl",
)
run_gptj_3d_test()
@pytest.mark.skip("TODO check_gptj has something wrong.")
@pytest.mark.dist
@rerun_if_address_is_in_use()
......
......@@ -112,7 +112,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
{
"tp_size": 4,
"pp_size": 1,
"enable_all_optimization": True,
"enable_all_optimization": False,
"use_lazy_init": False,
"precision": "fp32",
},
......@@ -124,7 +124,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"use_lazy_init": False,
"precision": "fp32",
},
{"tp_size": 2, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"},
{"tp_size": 2, "pp_size": 1, "enable_all_optimization": False, "use_lazy_init": False, "precision": "fp32"},
{
"tp_size": 2,
"pp_size": 1,
......
......@@ -29,7 +29,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
)
org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
org_model,
sharded_model,
sharded_optimizer,
data_gen_fn,
output_transform_fn,
criterion,
booster,
)
stage_manager = booster.plugin.stage_manager
......@@ -39,7 +45,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
opt_model = unwrap_model(org_model, "OPTModel", "model")
shard_opt_model = unwrap_model(sharded_model, "OPTModel", "model")
row_layer_for_check = ["decoder.layers[0].self_attn.q_proj", "decoder.embed_tokens"] # 'decoder.embed_tokens'
row_layer_for_check = [
"decoder.layers[0].self_attn.q_proj",
"decoder.embed_tokens",
] # 'decoder.embed_tokens'
col_layer_for_check = ["decoder.layers[0].self_attn.out_proj"]
# Save gradient tensors for comparison between the original model and the sharded model.
......@@ -50,10 +59,24 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
else:
atol, rtol = 4e-2, 4e-2
row_layer_grads = get_grad_tensors_for_check(
opt_model, shard_opt_model, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False
opt_model,
shard_opt_model,
row_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=0,
verbose=False,
)
col_layer_grads = get_grad_tensors_for_check(
opt_model, shard_opt_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
opt_model,
shard_opt_model,
col_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=1,
verbose=False,
)
grads_to_check.update(col_layer_grads)
grads_to_check.update(row_layer_grads)
......@@ -80,7 +103,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
else:
atol, rtol = 5e-3, 5e-3
check_weight(
opt_model, shard_opt_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
opt_model,
shard_opt_model,
col_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=1,
verbose=False,
)
# check grads
......@@ -110,8 +140,20 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"use_lazy_init": False,
"precision": "fp32",
},
{"tp_size": 4, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"},
{"tp_size": 2, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"},
{
"tp_size": 4,
"pp_size": 1,
"enable_all_optimization": False,
"use_lazy_init": False,
"precision": "fp32",
},
{
"tp_size": 2,
"pp_size": 1,
"enable_all_optimization": False,
"use_lazy_init": False,
"precision": "fp32",
},
{
"tp_size": 2,
"pp_size": 1,
......@@ -135,7 +177,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
)
def run_opt_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_opt")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
for name, (
model_fn,
data_gen_fn,
output_transform_fn,
loss_fn,
_,
) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
......@@ -169,7 +217,13 @@ def run_opt_test(test_config):
def run_opt_3d_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_opt")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
for name, (
model_fn,
data_gen_fn,
output_transform_fn,
loss_fn,
_,
) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
......@@ -178,13 +232,27 @@ def run_opt_3d_test(test_config):
def check_OPTModel(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
colossalai.launch(
config={},
rank=rank,
world_size=world_size,
host="localhost",
port=port,
backend="nccl",
)
run_opt_test()
def check_opt_3d(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
colossalai.launch(
config={},
rank=rank,
world_size=world_size,
host="localhost",
port=port,
backend="nccl",
)
run_opt_3d_test()
......
......@@ -25,7 +25,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
)
org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
org_model,
sharded_model,
sharded_optimizer,
data_gen_fn,
output_transform_fn,
criterion,
booster,
)
stage_manager = booster.plugin.stage_manager
......@@ -71,7 +77,16 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
else:
atol, rtol = 5e-3, 5e-3
if stage_manager is None or stage_manager.is_first_stage():
check_weight(t5, sharded_t5, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False)
check_weight(
t5,
sharded_t5,
row_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=0,
verbose=False,
)
# check grads
check_all_grad_tensors(grads_to_check)
......@@ -104,7 +119,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
{
"tp_size": 4,
"pp_size": 1,
"enable_all_optimization": True,
"enable_all_optimization": False,
"use_lazy_init": False,
"precision": "fp32",
},
......@@ -117,7 +132,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"use_lazy_init": False,
"precision": "fp32",
},
{"tp_size": 2, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"},
{
"tp_size": 2,
"pp_size": 1,
......@@ -144,7 +158,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
def run_t5_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_t5")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
for name, (
model_fn,
data_gen_fn,
output_transform_fn,
loss_fn,
_,
) in sub_model_zoo.items():
# skip 4-stage pp test for t5_encoder
if test_config["pp_size"] > 2 and name == "transformers_t5_encoder_model":
continue
......@@ -185,7 +205,13 @@ def run_t5_test(test_config):
def run_t5_3d_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_t5")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
for name, (
model_fn,
data_gen_fn,
output_transform_fn,
loss_fn,
_,
) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
......@@ -194,13 +220,27 @@ def run_t5_3d_test(test_config):
def check_t5(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
colossalai.launch(
config={},
rank=rank,
world_size=world_size,
host="localhost",
port=port,
backend="nccl",
)
run_t5_test()
def check_t5_3d(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
colossalai.launch(
config={},
rank=rank,
world_size=world_size,
host="localhost",
port=port,
backend="nccl",
)
run_t5_3d_test()
......
import math
import pytest
import torch
from einops import rearrange
from colossalai.kernel.extensions.flash_attention import HAS_FLASH_ATTN, HAS_MEM_EFF_ATTN
from colossalai.testing import clear_cache_before_run, parameterize
if HAS_MEM_EFF_ATTN or HAS_FLASH_ATTN:
from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention
DTYPE = [torch.float16, torch.bfloat16, torch.float32]
def attention_ref(q, k, v, attn_mask=None, causal=False):
"""
    Reference attention output, used as the control group for comparison.
"""
dtype_og = q.dtype
seqlen_q, seqlen_k = q.shape[1], k.shape[1]
d = q.shape[-1]
scale = 1.0 / math.sqrt(d)
scores = torch.einsum("bthd,bshd->bhts", q * scale, k)
if attn_mask is not None:
scores.masked_fill_(rearrange(~attn_mask, "b s -> b 1 1 s"), float("-inf"))
if causal:
causal_mask = torch.triu(torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1)
scores.masked_fill_(causal_mask, float("-inf"))
attention = torch.softmax(scores, dim=-1)
output = torch.einsum("bhts,bshd->bthd", attention, v)
output = rearrange(output, "b s h d -> b s (h d)")
    # Zero out the output at padded (masked-out) positions
if attn_mask is not None:
output.masked_fill_(rearrange(~attn_mask, "b s -> b s 1"), 0.0)
return output.to(dtype=dtype_og)
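# --- Editorial sketch (not in the original file) ----------------------------
# attention_ref above computes softmax(q @ k^T / sqrt(d)) @ v in [B, S, H, D]
# layout. As a hedged cross-check, the same result can be obtained from
# PyTorch's fused torch.nn.functional.scaled_dot_product_attention, which
# expects [B, H, S, D]. The helper below is an assumption-laden sketch, not
# part of the test suite.
import torch.nn.functional as F


def sdpa_ref(q, k, v, attn_mask=None, causal=False):
    # q, k, v: [B, S, H, D] as in attention_ref; SDPA expects [B, H, S, D].
    q_, k_, v_ = (rearrange(t, "b s h d -> b h s d") for t in (q, k, v))
    if attn_mask is None:
        out = F.scaled_dot_product_attention(q_, k_, v_, is_causal=causal)
    else:
        # Combine the [B, S] key-padding mask (True = keep) with a causal mask,
        # since SDPA does not accept attn_mask and is_causal together.
        mask = rearrange(attn_mask, "b s -> b 1 1 s")
        if causal:
            causal_mask = torch.tril(
                torch.ones(q.shape[1], k.shape[1], dtype=torch.bool, device=q.device)
            )
            mask = mask & causal_mask
        out = F.scaled_dot_product_attention(q_, k_, v_, attn_mask=mask)
    out = rearrange(out, "b h s d -> b s (h d)")
    if attn_mask is not None:
        # attention_ref also zeroes the output at padded query positions.
        out = out.masked_fill(rearrange(~attn_mask, "b s -> b s 1"), 0.0)
    return out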
@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="neither xformers memory-efficient attention nor flash attention is available")
@clear_cache_before_run()
@parameterize("proj_shape", [(6, 8, 4, 16)])
@parameterize("dtype", DTYPE)
@parameterize("dropout", [0.0])
def test_attention_gpt(proj_shape, dtype, dropout):
(B, S, H, D_HEAD) = proj_shape
D = H * D_HEAD
q = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
k = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
v = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
mask = [torch.ones(S - i, dtype=torch.bool, device="cuda") for i in range(B)]
mask = torch.nn.utils.rnn.pad_sequence(mask, batch_first=True)
attn = ColoAttention(D, H, dropout=dropout)
y = attn(q, k, v, attn_mask=mask, attn_mask_type=AttnMaskType.paddedcausal)
assert list(y.shape) == [B, S, D]
out_ref = attention_ref(q, k, v, mask, causal=True)
# check gradients
dy = torch.rand_like(y)
grad_q, grad_k, grad_v = torch.autograd.grad(y, (q, k, v), dy)
grad_ref_q, grad_ref_k, grad_ref_v = torch.autograd.grad(out_ref, (q, k, v), dy)
    assert torch.allclose(y, out_ref, atol=1e-7), f"{(y - out_ref).abs().max()}"
    assert torch.allclose(grad_q, grad_ref_q, atol=1e-7), f"{(grad_q - grad_ref_q).abs().max()}"
    assert torch.allclose(grad_k, grad_ref_k, atol=1e-7), f"{(grad_k - grad_ref_k).abs().max()}"
    assert torch.allclose(grad_v, grad_ref_v, atol=1e-7), f"{(grad_v - grad_ref_v).abs().max()}"
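# --- Editorial sketch (not in the original file) ----------------------------
# The pad_sequence call above gives batch element i a valid length of S - i,
# so the reference must apply both the key-padding mask and the causal mask.
# A minimal sketch of materializing that combined [B, 1, S, S] boolean mask
# (names here are illustrative, not ColoAttention internals):
def build_padded_causal_mask(batch_size: int = 6, seq_len: int = 8) -> torch.Tensor:
    lengths = torch.tensor([seq_len - i for i in range(batch_size)])     # valid tokens per sequence
    key_padding = torch.arange(seq_len)[None, :] < lengths[:, None]      # [B, S], True = real token
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))  # [S, S]
    return key_padding[:, None, None, :] & causal                        # [B, 1, S, S]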
@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="neither xformers memory-efficient attention nor flash attention is available")
@clear_cache_before_run()
@parameterize("proj_shape", [(6, 8, 4, 16)])
@parameterize("dtype", DTYPE)
@parameterize("dropout", [0.0])
def test_attention_bert(proj_shape, dtype, dropout):
(B, S, H, D_HEAD) = proj_shape
D = H * D_HEAD
q = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
k = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
v = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
    # random boolean attention mask of shape [B, S]; False marks padded positions
mask = torch.randint(0, 2, (B, S), dtype=torch.bool, device="cuda")
attn = ColoAttention(D, H, dropout=dropout)
y = attn(q, k, v, attn_mask=mask, attn_mask_type=AttnMaskType.padding)
assert list(y.shape) == [B, S, D]
out_ref = attention_ref(q, k, v, mask, causal=False)
dy = torch.rand_like(y)
grad_q, grad_k, grad_v = torch.autograd.grad(y, (q, k, v), dy)
grad_ref_q, grad_ref_k, grad_ref_v = torch.autograd.grad(out_ref, (q, k, v), dy)
    assert torch.allclose(y, out_ref, atol=1e-7), f"{(y - out_ref).abs().max()}"
    assert torch.allclose(grad_q, grad_ref_q, atol=1e-7), f"{(grad_q - grad_ref_q).abs().max()}"
    assert torch.allclose(grad_k, grad_ref_k, atol=1e-7), f"{(grad_k - grad_ref_k).abs().max()}"
    assert torch.allclose(grad_v, grad_ref_v, atol=1e-7), f"{(grad_v - grad_ref_v).abs().max()}"
@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="neither xformers memory-efficient attention nor flash attention is available")
@clear_cache_before_run()
@parameterize("proj_shape", [(6, 8, 4, 16)])
@parameterize("dtype", DTYPE)
@parameterize("dropout", [0.0])
def test_attention_no_mask(proj_shape, dtype, dropout):
(B, S, H, D_HEAD) = proj_shape
D = H * D_HEAD
q = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
k = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
v = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
attn = ColoAttention(D, H, dropout=dropout)
y = attn(q, k, v)
assert list(y.shape) == [B, S, D]
out_ref = attention_ref(q, k, v, None, causal=False)
dy = torch.rand_like(y)
grad_q, grad_k, grad_v = torch.autograd.grad(y, (q, k, v), dy)
grad_ref_q, grad_ref_k, grad_ref_v = torch.autograd.grad(out_ref, (q, k, v), dy)
    assert torch.allclose(y, out_ref, atol=1e-7), f"{(y - out_ref).abs().max()}"
    assert torch.allclose(grad_q, grad_ref_q, atol=1e-7), f"{(grad_q - grad_ref_q).abs().max()}"
    assert torch.allclose(grad_k, grad_ref_k, atol=1e-7), f"{(grad_k - grad_ref_k).abs().max()}"
    assert torch.allclose(grad_v, grad_ref_v, atol=1e-7), f"{(grad_v - grad_ref_v).abs().max()}"
@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="neither xformers memory-efficient attention nor flash attention is available")
@clear_cache_before_run()
@parameterize("proj_shape", [(6, 24, 8, 4, 16)])
@parameterize("dtype", DTYPE)
@parameterize("dropout", [0.0])
def test_cross_attention(proj_shape, dtype, dropout):
(B, S, T, H, D_HEAD) = proj_shape
D = H * D_HEAD
q = torch.randn((B, T, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
k = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
v = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
attn = ColoAttention(D, H, dropout=dropout)
y = attn(q, k, v, attn_mask_type=AttnMaskType.causal)
assert list(y.shape) == [B, T, D]
out_ref = attention_ref(q, k, v, None, causal=True)
dy = torch.rand_like(y)
grad_q, grad_k, grad_v = torch.autograd.grad(y, (q, k, v), dy)
grad_ref_q, grad_ref_k, grad_ref_v = torch.autograd.grad(out_ref, (q, k, v), dy)
    assert torch.allclose(y, out_ref, atol=1e-7), f"{(y - out_ref).abs().max()}"
    assert torch.allclose(grad_q, grad_ref_q, atol=1e-7), f"{(grad_q - grad_ref_q).abs().max()}"
    assert torch.allclose(grad_k, grad_ref_k, atol=1e-7), f"{(grad_k - grad_ref_k).abs().max()}"
    assert torch.allclose(grad_v, grad_ref_v, atol=1e-7), f"{(grad_v - grad_ref_v).abs().max()}"
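# --- Editorial sketch (not in the original file) ----------------------------
# test_cross_attention uses query length T = 8 against key length S = 24 with
# a causal mask; attention_ref builds torch.triu(ones(T, S), 1), so causality
# is anchored at the top-left corner: query position i may attend to key
# positions <= i. A small illustration of that non-square mask:
def illustrate_nonsquare_causal_mask(t: int = 8, s: int = 24) -> torch.Tensor:
    blocked = torch.triu(torch.ones(t, s, dtype=torch.bool), diagonal=1)  # True = masked out
    allowed = ~blocked                                                    # True = may attend
    assert allowed[0].sum() == 1 and allowed[t - 1].sum() == t
    return allowed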