"docs/source/vscode:/vscode.git/clone" did not exist on "e3342edc4ef2f6046595ab2d5660640a313501fe"
Unverified Commit 9cf4f2aa authored by RhuiDih, committed by GitHub

Enhancing SFT Training Efficiency Using Packing and FlashAttention2 with Position IDs (#31629)

* add DataCollatorBatchFlattening

* Update data_collator.py

* change name

* new FA2 flow if position_ids is provided

* add comments

* minor fix

* minor fix data collator

* add test cases for models

* add test case for data collator

* remove extra code

* formatting for ruff check and check_repo.py

* ruff format

ruff format tests src utils

* custom_init_isort.py
@@ -66,3 +66,8 @@ Examples of use can be found in the [example scripts](../examples) or [example n
- numpy_mask_tokens
- tf_mask_tokens
- torch_mask_tokens
## DataCollatorWithFlattening
[[autodoc]] data.data_collator.DataCollatorWithFlattening
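
A minimal sketch of the collator on a toy batch (outputs follow the collator logic added in this PR; `return_tensors` defaults to `"pt"` via `DefaultDataCollator`):

```python
from transformers import DataCollatorWithFlattening

collator = DataCollatorWithFlattening()
batch = collator([{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5, 6, 7]}])

print(batch["input_ids"])     # tensor([[1, 2, 3, 4, 5, 6, 7]])
print(batch["position_ids"])  # tensor([[0, 1, 2, 0, 1, 2, 3]])   restarts at each example
print(batch["labels"])        # tensor([[-100, 2, 3, -100, 5, 6, 7]])   boundaries masked
```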
@@ -103,6 +103,7 @@ _import_structure = {
"DataCollatorForSOP",
"DataCollatorForTokenClassification",
"DataCollatorForWholeWordMask",
"DataCollatorWithFlattening",
"DataCollatorWithPadding",
"DefaultDataCollator",
"default_data_collator",
@@ -4764,6 +4765,7 @@ if TYPE_CHECKING:
DataCollatorForSOP,
DataCollatorForTokenClassification,
DataCollatorForWholeWordMask,
DataCollatorWithFlattening,
DataCollatorWithPadding,
DefaultDataCollator,
default_data_collator,
......
@@ -19,6 +19,7 @@ from .data_collator import (
DataCollatorForSOP,
DataCollatorForTokenClassification,
DataCollatorForWholeWordMask,
DataCollatorWithFlattening,
DataCollatorWithPadding,
DefaultDataCollator,
default_data_collator,
......
@@ -1611,3 +1611,38 @@ class DataCollatorForPermutationLanguageModeling(DataCollatorMixin):
) & masked_indices[i]
return inputs.astype(np.int64), perm_mask, target_mapping, labels.astype(np.int64)

@dataclass
class DataCollatorWithFlattening(DefaultDataCollator):
    """
    Data collator used for the padding-free approach. Does the following:

    - concatenates the whole mini-batch into a single long sequence of shape [1, total_tokens]
    - adds no padding and returns `input_ids`, `labels` and `position_ids`
    """

    def __init__(self, *args, return_position_ids=True, **kwargs):
        super().__init__(*args, **kwargs)
        self.return_position_ids = return_position_ids
        warnings.warn(
            "Using `DataCollatorWithFlattening` will flatten the entire mini batch into a single long sequence. "
            "Make sure your attention computation is able to handle it!"
        )

    def __call__(self, features, return_tensors=None):
        if return_tensors is None:
            return_tensors = self.return_tensors
        is_labels_provided = "labels" in features[0]
        ret = {"input_ids": [], "labels": []}
        if self.return_position_ids:
            ret.update({"position_ids": []})
        for idx in range(len(features)):
            ret["input_ids"] += features[idx]["input_ids"]
            # mask the label of the first token of every example so that no loss
            # is computed across example boundaries
            if is_labels_provided:
                ret["labels"] += [-100] + features[idx]["labels"][1:]
            else:
                ret["labels"] += [-100] + features[idx]["input_ids"][1:]
            if self.return_position_ids:
                ret["position_ids"] += list(range(len(features[idx]["input_ids"])))
        return default_data_collator([ret], return_tensors)
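
# Usage sketch (assumptions: a CUDA device and an FA2-capable checkpoint; the
# model name below is illustrative, not part of this PR). The flattened batch
# can be fed directly to a causal LM loaded with Flash Attention 2, where the
# position_ids prevent cross-example attention:
#
#     import torch
#     from transformers import AutoModelForCausalLM, DataCollatorWithFlattening
#
#     model = AutoModelForCausalLM.from_pretrained(
#         "meta-llama/Llama-2-7b-hf",  # assumed FA2-capable checkpoint
#         torch_dtype=torch.float16,
#         attn_implementation="flash_attention_2",
#     ).to("cuda")
#     collator = DataCollatorWithFlattening()
#     batch = collator([{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5, 6, 7]}])
#     loss = model(**{k: v.to("cuda") for k, v in batch.items()}).loss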
@@ -130,6 +130,56 @@ def _upad_input(
)

def prepare_fa2_from_position_ids(query, key, value, position_ids):
    """
    This function returns the necessary arguments to call `flash_attn_varlen_func`.
    All three query, key, value states will be flattened.
    Cumulative lengths of each example in the batch will be extracted from `position_ids`.

    NOTE: ideally, the cumulative lengths should be prepared at the data collator stage

    Arguments:
        query (`torch.Tensor`):
            Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
        key (`torch.Tensor`):
            Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
        value (`torch.Tensor`):
            Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
        position_ids (`torch.Tensor`):
            Int tensor of shape (batch_size, sequence_length) holding each token's position within its packed
            example; positions restart at 0 at the beginning of every packed example.

    Return:
        query (`torch.Tensor`):
            Query state without padding. Shape: (total_target_length, num_heads, head_dim).
        key (`torch.Tensor`):
            Key state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
        value (`torch.Tensor`):
            Value state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
        indices_q (`torch.Tensor`):
            The indices of non-masked tokens from the flattened input target sequence.
        (cu_seqlens_q, cu_seqlens_k) (`Tuple[int]`):
            The cumulative sequence lengths for the target (query) and source (key, value), used to index into
            ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
        (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`):
            Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query,
            `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
    """
    query = query.view(-1, query.size(-2), query.size(-1))
    key = key.view(-1, key.size(-2), key.size(-1))
    value = value.view(-1, value.size(-2), value.size(-1))
    position_ids = position_ids.flatten()
    indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)

    # every position equal to 0 marks the start of a packed example; together with
    # the total token count, these offsets form the cumulative sequence lengths
    cu_seq_lens = torch.cat(
        (
            indices_q[position_ids == 0],
            torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
        )
    )
    max_length = position_ids.max() + 1

    return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length))
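
# Worked example (illustrative sketch, not part of this diff): for a batch packed
# as three examples of lengths 3, 6 and 7,
#     position_ids = torch.tensor([0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6])
# every 0 marks the start of a packed example, so
#     indices_q[position_ids == 0]      ->  tensor([0, 3, 9])
#     cat with the total length (16)    ->  cu_seq_lens = tensor([0, 3, 9, 16])
#     position_ids.max() + 1            ->  max_length = 7
# which are exactly the `cu_seqlens` / `max_seqlen` arguments that
# `flash_attn_varlen_func` expects.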
def _flash_attention_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
@@ -138,6 +188,7 @@ def _flash_attention_forward(
    query_length: int,
    is_causal: bool,
    dropout: float = 0.0,
    position_ids: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    sliding_window: Optional[int] = None,
    use_top_left_mask: bool = False,
@@ -210,6 +261,34 @@
            **flash_kwargs,
        )
        attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)

    # If position_ids is provided and not every example (row) contains a single full sequence,
    # use `flash_attn_varlen_func` to prevent cross-example attention and to support the
    # padding-free approach.
    elif position_ids is not None and not (position_ids[:, -1] == position_ids.size(1) - 1).all():
        batch_size = query_states.size(0)
        query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
            query_states, key_states, value_states, position_ids
        )

        cu_seqlens_q, cu_seqlens_k = cu_seq_lens
        max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

        attn_output = flash_attn_varlen_func(
            query_states,
            key_states,
            value_states,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_in_batch_q,
            max_seqlen_k=max_seqlen_in_batch_k,
            dropout_p=dropout,
            softmax_scale=softmax_scale,
            causal=causal,
            **flash_kwargs,
        )

        attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))

    else:
        attn_output = flash_attn_func(
            query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal, **flash_kwargs
......
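
# Sketch of the packing check in the `elif` above (illustrative, not part of the
# diff): a row whose last position equals sequence_length - 1 is one unbroken
# sequence, so the varlen path only triggers when some row's positions reset
# partway through.
#     position_ids = [[0, 1, 2, 3]]  ->  last == 3 == 4 - 1  ->  regular flash_attn_func
#     position_ids = [[0, 1, 0, 1]]  ->  last == 1 != 4 - 1  ->  flash_attn_varlen_func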
@@ -415,6 +415,7 @@ class DbrxFlashAttention2(DbrxAttention):
value_states,
attention_mask,
q_len,
position_ids=position_ids,
dropout=dropout_rate,
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
......
@@ -602,6 +602,7 @@ class FalconFlashAttention2(FalconAttention):
value_layer,
attention_mask,
query_length,
position_ids=position_ids,
dropout=attn_dropout,
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
......
@@ -393,6 +393,7 @@ class GemmaFlashAttention2(GemmaAttention):
value_states,
attention_mask,
q_len,
position_ids=position_ids,
dropout=dropout_rate,
sliding_window=getattr(self, "sliding_window", None),
is_causal=self.is_causal,
......
@@ -503,6 +503,7 @@ class LlamaFlashAttention2(LlamaAttention):
value_states,
attention_mask,
q_len,
position_ids=position_ids,
dropout=dropout_rate,
sliding_window=getattr(self, "sliding_window", None),
use_top_left_mask=self._flash_attn_uses_top_left_mask,
......
@@ -382,6 +382,7 @@ class MistralFlashAttention2(MistralAttention):
value_states,
attention_mask,
q_len,
position_ids=position_ids,
dropout=dropout_rate,
sliding_window=getattr(self.config, "sliding_window", None),
use_top_left_mask=self._flash_attn_uses_top_left_mask,
......
@@ -488,6 +488,7 @@ class MixtralFlashAttention2(MixtralAttention):
value_states,
attention_mask,
q_len,
position_ids=position_ids,
dropout=dropout_rate,
sliding_window=getattr(self.config, "sliding_window", None),
is_causal=self.is_causal,
......
@@ -428,6 +428,7 @@ class OlmoFlashAttention2(OlmoAttention):
value_states,
attention_mask,
q_len,
position_ids=position_ids,
dropout=dropout_rate,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=self.is_causal,
......
@@ -501,6 +501,7 @@ class PhiFlashAttention2(PhiAttention):
value_states,
attention_mask,
q_len,
position_ids=position_ids,
dropout=attn_dropout,
softmax_scale=None,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
......
@@ -563,6 +563,7 @@ class Phi3FlashAttention2(Phi3Attention):
value_states,
attention_mask,
q_len,
position_ids=position_ids,
dropout=attn_dropout,
sliding_window=getattr(self.config, "sliding_window", None),
use_top_left_mask=self._flash_attn_uses_top_left_mask,
......
@@ -429,6 +429,7 @@ class Qwen2FlashAttention2(Qwen2Attention):
value_states,
attention_mask,
q_len,
position_ids=position_ids,
dropout=dropout_rate,
sliding_window=sliding_window,
is_causal=self.is_causal,
......
@@ -508,6 +508,7 @@ class Qwen2MoeFlashAttention2(Qwen2MoeAttention):
value_states,
attention_mask,
q_len,
position_ids=position_ids,
dropout=dropout_rate,
sliding_window=sliding_window,
is_causal=self.is_causal,
......
@@ -606,6 +606,7 @@ class StableLmFlashAttention2(StableLmAttention):
value_states,
attention_mask,
q_len,
position_ids=position_ids,
dropout=dropout_rate,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=self.is_causal,
......
@@ -404,6 +404,7 @@ class Starcoder2FlashAttention2(Starcoder2Attention):
value_states,
attention_mask,
q_len,
position_ids=position_ids,
dropout=dropout_rate,
sliding_window=getattr(self.config, "sliding_window", None),
is_causal=self.is_causal,
......
@@ -4327,6 +4327,78 @@ class ModelTesterMixin:
# with attention mask
_ = model(dummy_input, attention_mask=dummy_attention_mask)
    @require_flash_attn
    @require_torch_gpu
    @mark.flash_attn_test
    @slow
    def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
        if not self.has_attentions:
            self.skipTest(reason="Model architecture does not support attentions")

        max_new_tokens = 30

        for model_class in self.all_generative_model_classes:
            if not model_class._supports_flash_attn_2:
                self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")

            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
            dummy_input = inputs_dict[model_class.main_input_name]
            if dummy_input.dtype in [torch.float32, torch.bfloat16]:
                dummy_input = dummy_input.to(torch.float16)

            # make sure that all models have enough positions for generation
            if hasattr(config, "max_position_embeddings"):
                config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1

            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)

                assert 0 in inputs_dict["attention_mask"], "there must be padding in the test inputs"
                # ensure left padding, to adapt for some models
                if 0 in inputs_dict["attention_mask"][:, -1]:
                    inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)
                dummy_attention_mask = inputs_dict["attention_mask"]
                inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.pad_token_id

                model = (
                    model_class.from_pretrained(
                        tmpdirname,
                        torch_dtype=torch.float16,
                        attn_implementation="flash_attention_2",
                        low_cpu_mem_usage=True,
                    )
                    .to(torch_device)
                    .eval()
                )

                # flatten the padded batch into a single sequence, dropping pad tokens
                padfree_inputs_dict = {
                    k: v[dummy_attention_mask.bool()].unsqueeze(0)
                    for k, v in inputs_dict.items()
                    if k != "attention_mask"
                }
                # add position_ids that restart at 0 for every packed example
                padfree_inputs_dict["position_ids"] = (
                    torch.cat([torch.arange(length) for length in dummy_attention_mask.sum(1).tolist()])
                    .long()
                    .unsqueeze(0)
                    .to(torch_device)
                )

                res_padded = model(**inputs_dict)
                res_padfree = model(**padfree_inputs_dict)

                logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()]
                logits_padfree = res_padfree.logits[0]

                torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), atol=0, rtol=0)
                # acceptable numerical instability
                tol = torch.finfo(torch.float16).eps
                torch.testing.assert_close(logits_padded, logits_padfree, atol=tol, rtol=tol)

    @is_pt_tf_cross_test
    def test_tf_from_pt_safetensors(self):
        for model_class in self.all_model_classes:
......
@@ -26,6 +26,7 @@ from transformers import (
DataCollatorForSeq2Seq,
DataCollatorForTokenClassification,
DataCollatorForWholeWordMask,
DataCollatorWithFlattening,
DataCollatorWithPadding,
default_data_collator,
is_tf_available,
@@ -1531,6 +1532,24 @@ class NumpyDataCollatorIntegrationTest(unittest.TestCase):
        batch = data_collator(features)
        self.assertEqual(batch["input_ids"].shape, (2, 8))

    def test_data_collator_with_flattening(self):
        features = [
            {"input_ids": [10, 11, 12]},
            {"input_ids": [20, 21, 22, 23, 24, 25]},
            {"input_ids": [30, 31, 32, 33, 34, 35, 36]},
        ]

        data_collator = DataCollatorWithFlattening(return_tensors="np")
        batch = data_collator(features)
        self.assertEqual(batch["input_ids"].shape, (1, 16))
        self.assertEqual(
            batch["input_ids"][0].tolist(), [10, 11, 12, 20, 21, 22, 23, 24, 25, 30, 31, 32, 33, 34, 35, 36]
        )
        self.assertNotIn("attention_mask", batch)
        self.assertIn("position_ids", batch)
        self.assertEqual(batch["position_ids"].shape, (1, 16))
        self.assertEqual(batch["position_ids"][0].tolist(), [0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6])

    def test_data_collator_for_token_classification(self):
        tokenizer = BertTokenizer(self.vocab_file)
        features = [
......