Commit 4a40151b authored by chenych

Update v0.8.3

parent 731cf9b8
@@ -35,7 +35,6 @@ from transformers.utils.versions import require_version
from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN
from ...extras.logging import get_logger
from ...extras.packages import is_transformers_version_greater_than_4_43
if TYPE_CHECKING:
@@ -51,15 +50,14 @@ transformers_logger = logging.get_logger(__name__)
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
def llama_attention_forward(
self: "LlamaAttention",
hidden_states: "torch.Tensor",
attention_mask: Optional["torch.Tensor"] = None,
position_ids: Optional["torch.LongTensor"] = None,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional["Cache"] = None,
output_attentions: bool = False,
cache_position: Optional["torch.LongTensor"] = None,
position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]:
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states: "torch.Tensor" = self.q_proj(hidden_states)
@@ -70,11 +68,7 @@ def llama_attention_forward(
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
if position_embeddings is None:
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
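This hunk touches the `position_embeddings` handling, but both versions end in `apply_rotary_pos_emb(query_states, key_states, cos, sin)`. For orientation, the rotary application reduces to a rotate-and-mix of each head's channels; the sketch below is illustrative, mirroring the transformers shape conventions rather than reproducing the project's code.

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # split the last dimension in half and rotate: (x1, x2) -> (-x2, x1)
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_sketch(q, k, cos, sin):
    # cos/sin broadcast over the head dimension; q, k: (bsz, n_heads, seq_len, head_dim)
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin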
@@ -136,15 +130,14 @@ def llama_attention_forward(
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
def llama_flash_attention_2_forward(
self: "LlamaFlashAttention2",
hidden_states: "torch.Tensor",
attention_mask: Optional["torch.Tensor"] = None,
position_ids: Optional["torch.LongTensor"] = None,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional["Cache"] = None,
output_attentions: bool = False,
cache_position: Optional["torch.LongTensor"] = None,
position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]:
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# LlamaFlashAttention2 attention does not support output_attentions
output_attentions = False
@@ -158,11 +151,7 @@ def llama_flash_attention_2_forward(
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
if position_embeddings is None:
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
@@ -209,24 +198,9 @@ def llama_flash_attention_2_forward(
if attention_mask is not None:
attention_mask = attention_mask[:, :groupsz].repeat(num_groups, 1)
if is_transformers_version_greater_than_4_43():
from transformers.modeling_flash_attention_utils import _flash_attention_forward
attn_output: "torch.Tensor" = _flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
query_states.size(1),
dropout=dropout_rate,
sliding_window=getattr(self, "sliding_window", None),
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=self.is_causal,
)
else:
attn_output: "torch.Tensor" = self._flash_attention_forward(
query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate
)
attn_output: "torch.Tensor" = self._flash_attention_forward(
query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate
)
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
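The version-gated branch shown above selects the module-level `_flash_attention_forward` helper on transformers >= 4.43 and the bound `self._flash_attention_forward` method otherwise. A minimal sketch of such a version gate, using only `importlib.metadata` and `packaging`; the helper name below is illustrative, not the project's `is_transformers_version_greater_than_4_43`.

from importlib.metadata import version as installed_version
from packaging.version import Version

def transformers_is_at_least(minimum: str) -> bool:
    # compare the installed transformers release against a threshold such as "4.43.0"
    return Version(installed_version("transformers")) >= Version(minimum)

# e.g. only import the module-level flash-attention helper when it exists:
# if transformers_is_at_least("4.43.0"):
#     from transformers.modeling_flash_attention_utils import _flash_attention_forward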
@@ -251,15 +225,14 @@ def llama_flash_attention_2_forward(
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
def llama_sdpa_attention_forward(
self: "LlamaSdpaAttention",
hidden_states: "torch.Tensor",
attention_mask: Optional["torch.Tensor"] = None,
position_ids: Optional["torch.LongTensor"] = None,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional["Cache"] = None,
output_attentions: bool = False,
cache_position: Optional["torch.LongTensor"] = None,
position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]:
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
transformers_logger.warning_once(
"SDPA does not support `output_attentions=True`. Falling back to the vanilla attention"
@@ -285,11 +258,7 @@ def llama_sdpa_attention_forward(
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
if position_embeddings is None:
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
@@ -353,7 +322,7 @@ def llama_sdpa_attention_forward(
def _apply_llama_patch() -> None:
require_version("transformers>=4.41.2,<=4.43.4", "To fix: pip install transformers>=4.41.2,<=4.43.4")
require_version("transformers>=4.41.2,<=4.42.4", "To fix: pip install transformers>=4.41.2,<=4.42.4")
LlamaAttention.forward = llama_attention_forward
LlamaFlashAttention2.forward = llama_flash_attention_2_forward
LlamaSdpaAttention.forward = llama_sdpa_attention_forward
......
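`_apply_llama_patch` above swaps the `forward` methods of the three LLaMA attention classes in place. A toy sketch of that monkey-patching pattern, on a stand-in class rather than the patched transformers classes:

class ToyAttention:  # stand-in for LlamaAttention
    def forward(self, hidden_states):
        return hidden_states  # stock behaviour

def patched_forward(self, hidden_states):
    # drop-in replacement with the same signature; custom attention logic would go here
    return hidden_states * 2

ToyAttention.forward = patched_forward  # same move as LlamaAttention.forward = llama_attention_forward
assert ToyAttention().forward(3) == 6   # method lookup goes through the class, so the patch takes effect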
@@ -41,11 +41,11 @@ from typing import TYPE_CHECKING, Tuple
import torch
import torch.nn.functional as F
import transformers.models
from transformers.utils.versions import require_version
from ...extras.constants import SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN
from ...extras.logging import get_logger
from ...extras.packages import is_transformers_version_greater_than_4_43
if TYPE_CHECKING:
@@ -114,15 +114,7 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor
def _patch_for_block_diag_attn(model_type: str) -> None:
require_version("transformers>=4.41.2,<=4.43.4", "To fix: pip install transformers>=4.41.2,<=4.43.4")
if is_transformers_version_greater_than_4_43():
import transformers.modeling_flash_attention_utils
transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
return
import transformers.models
require_version("transformers>=4.41.2,<=4.42.4", "To fix: pip install transformers>=4.41.2,<=4.42.4")
if model_type == "cohere":
transformers.models.cohere.modeling_cohere._get_unpad_data = get_unpad_data
elif model_type == "falcon":
......
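`get_unpad_data` above is what gets patched into transformers' `_get_unpad_data` so that flash-attention's varlen kernels see packed sequences as a block-diagonal mask. A simplified sketch for a single packed row, where the mask holds segment ids (1, 1, 2, 2, ..., with 0 for padding); this illustrates the idea and is not the project's implementation.

import torch
import torch.nn.functional as F

def unpad_data_sketch(attention_mask: torch.Tensor):
    # attention_mask: (1, seq_len) of packed-segment ids, with 0 marking padding
    flat = attention_mask.flatten()
    indices = torch.nonzero(flat, as_tuple=False).flatten()                      # positions of real tokens
    seqlens = torch.bincount(flat[indices].to(torch.int64))[1:].to(torch.int32)  # length of each segment
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))  # prefix sums with a leading 0
    return indices, cu_seqlens, int(seqlens.max())

# unpad_data_sketch(torch.tensor([[1, 1, 2, 2, 2, 0]]))[1] -> tensor([0, 2, 5], dtype=torch.int32)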
@@ -21,6 +21,7 @@ from peft import PeftModel
from transformers import PreTrainedModel, PreTrainedTokenizerBase, is_torch_npu_available
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled
from transformers.utils.versions import require_version
from ..extras.logging import get_logger
from ..extras.misc import infer_optim_dtype
@@ -88,6 +89,9 @@ def patch_config(
if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn == "fa2":
setattr(config, "use_cache", False) # qwen2 does not support use_cache when using flash attn
if getattr(config, "model_type", None) == "chatglm":
require_version("transformers==4.41.2", "To fix: pip install transformers==4.41.2")
# deepspeed zero3 is not compatible with low_cpu_mem_usage
init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())
......
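The comment carried in the hunk above is the whole rationale: under DeepSpeed ZeRO-3 the weights are partitioned at construction time, which conflicts with the meta-device fast path behind `low_cpu_mem_usage`, so the flag is only honoured when ZeRO-3 is off. A hedged sketch of that gate in isolation (`requested` stands in for `model_args.low_cpu_mem_usage`):

from transformers.integrations import is_deepspeed_zero3_enabled

requested = True  # stand-in for model_args.low_cpu_mem_usage
init_kwargs = {"low_cpu_mem_usage": requested and not is_deepspeed_zero3_enabled()}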
@@ -162,12 +162,10 @@ class PissaConvertCallback(TrainerCallback):
setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
model.save_pretrained(
pissa_convert_dir, safe_serialization=args.save_safetensors, convert_pissa_to_lora=pissa_init_dir
) # TODO: use `path_initial_model_for_weight_conversion` (peft>=0.12.0)
)
model.load_adapter(pissa_backup_dir, "default", is_trainable=True)
model.set_adapter("default")
if "pissa_init" in model.peft_config.keys(): # backward compatibility (peft<0.12.0)
model.delete_adapter("pissa_init")
model.delete_adapter("pissa_init")
setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
......
@@ -29,7 +29,7 @@ from trl.trainer import disable_dropout_in_model
from ...extras.constants import IGNORE_INDEX
from ..callbacks import PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps
if TYPE_CHECKING:
@@ -106,7 +106,7 @@ class CustomDPOTrainer(DPOTrainer):
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
return super().create_optimizer()
def create_scheduler(
......
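The `create_optimizer` override above recurs in every trainer touched by this commit: ask the project factory first, and if it returns nothing let the stock `Trainer` build its default optimizer. A self-contained sketch of the pattern (the factory here is a hypothetical placeholder, not the `create_custom_optimizer`/`create_custom_optimzer` factory in the hunks):

from typing import Optional

import torch
from transformers import Trainer

def build_optimizer(model, args) -> Optional[torch.optim.Optimizer]:
    # hypothetical factory: return a custom optimizer, or None to fall back to the default
    return None

class SketchTrainer(Trainer):
    def create_optimizer(self) -> "torch.optim.Optimizer":
        if self.optimizer is None:
            self.optimizer = build_optimizer(self.model, self.args)
        return super().create_optimizer()  # only builds one if self.optimizer is still None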
@@ -28,7 +28,7 @@ from trl.trainer import disable_dropout_in_model
from ...extras.constants import IGNORE_INDEX
from ..callbacks import SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps
if TYPE_CHECKING:
@@ -101,7 +101,7 @@ class CustomKTOTrainer(KTOTrainer):
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
return super().create_optimizer()
def create_scheduler(
......
@@ -39,7 +39,7 @@ from trl.models.utils import unwrap_model_for_generation
from ...extras.logging import get_logger
from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor
from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
from .ppo_utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm
@@ -133,7 +133,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
ref_model=ref_model,
tokenizer=tokenizer,
dataset=train_dataset,
optimizer=optimizer,
data_collator=data_collator,
lr_scheduler=scheduler,
)
@@ -304,7 +303,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
) -> "torch.optim.Optimizer":
optimizer = create_custom_optimizer(model, training_args, finetuning_args)
optimizer = create_custom_optimzer(model, training_args, finetuning_args)
if optimizer is None:
decay_params, nodecay_params = [], []
decay_param_names = self.get_decay_parameter_names(model)
......
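When the factory returns `None`, the PPO trainer above falls back to splitting parameters into decay and no-decay groups via `get_decay_parameter_names`. A toy version of that split on a small module, using the common heuristic that only matrices (ndim >= 2) receive weight decay in place of the name-based bias/LayerNorm exclusion:

import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.LayerNorm(4))
decay_params = [p for p in model.parameters() if p.requires_grad and p.ndim >= 2]
nodecay_params = [p for p in model.parameters() if p.requires_grad and p.ndim < 2]
optimizer = torch.optim.AdamW(
    [
        {"params": decay_params, "weight_decay": 0.01},   # weights of the Linear layer
        {"params": nodecay_params, "weight_decay": 0.0},  # biases and LayerNorm parameters
    ],
    lr=1e-5,
)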
@@ -19,7 +19,7 @@ from transformers import Trainer
from ...extras.logging import get_logger
from ..callbacks import PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
if TYPE_CHECKING:
@@ -57,7 +57,7 @@ class CustomTrainer(Trainer):
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
return super().create_optimizer()
def create_scheduler(
......
@@ -25,7 +25,7 @@ from transformers import Trainer
from ...extras.logging import get_logger
from ..callbacks import FixValueHeadModelCallback, PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
if TYPE_CHECKING:
@@ -65,7 +65,7 @@ class PairwiseTrainer(Trainer):
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
return super().create_optimizer()
def create_scheduler(
......
@@ -27,7 +27,7 @@ from transformers import Seq2SeqTrainer
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from ..callbacks import PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
if TYPE_CHECKING:
@@ -66,7 +66,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
return super().create_optimizer()
def create_scheduler(
......
@@ -22,7 +22,6 @@ from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
import torch
from transformers import Trainer
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled
from transformers.optimization import get_scheduler
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.trainer_pt_utils import get_parameter_names
@@ -367,32 +366,7 @@ def _create_badam_optimizer(
return optimizer
def _create_adam_mini_optimizer(
model: "PreTrainedModel",
training_args: "Seq2SeqTrainingArguments",
) -> "torch.optim.Optimizer":
from adam_mini import Adam_mini
hidden_size = getattr(model.config, "hidden_size", None)
num_q_head = getattr(model.config, "num_attention_heads", None)
num_kv_head = getattr(model.config, "num_key_value_heads", None)
optimizer = Adam_mini(
named_parameters=model.named_parameters(),
lr=training_args.learning_rate,
betas=(training_args.adam_beta1, training_args.adam_beta2),
eps=training_args.adam_epsilon,
weight_decay=training_args.weight_decay,
model_sharding=is_fsdp_enabled() or is_deepspeed_zero3_enabled(),
dim=hidden_size,
n_heads=num_q_head,
n_kv_heads=num_kv_head,
)
logger.info("Using Adam-mini optimizer.")
return optimizer
def create_custom_optimizer(
def create_custom_optimzer(
model: "PreTrainedModel",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
@@ -406,9 +380,6 @@ def create_custom_optimizer(
if finetuning_args.use_badam:
return _create_badam_optimizer(model, training_args, finetuning_args)
if finetuning_args.use_adam_mini:
return _create_adam_mini_optimizer(model, training_args)
def create_custom_scheduler(
training_args: "Seq2SeqTrainingArguments",
......
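`create_custom_scheduler`, cut off above, pairs with the optimizer factory; the module imports transformers' `get_scheduler` for this. A minimal usage sketch with illustrative warmup and step counts:

import torch
from transformers.optimization import get_scheduler

params = [torch.nn.Parameter(torch.zeros(2, 2))]
optimizer = torch.optim.AdamW(params, lr=1e-4)
scheduler = get_scheduler(
    "cosine",                # any transformers SchedulerType name
    optimizer=optimizer,
    num_warmup_steps=100,    # illustrative values
    num_training_steps=1000,
)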
@@ -66,7 +66,7 @@ def save_model(
error = ALERTS["err_no_dataset"][lang]
elif export_quantization_bit not in GPTQ_BITS and not checkpoint_path:
error = ALERTS["err_no_adapter"][lang]
elif export_quantization_bit in GPTQ_BITS and checkpoint_path and isinstance(checkpoint_path, list):
elif export_quantization_bit in GPTQ_BITS and isinstance(checkpoint_path, list):
error = ALERTS["err_gptq_lora"][lang]
if error:
@@ -104,7 +104,7 @@ def save_model(
def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Row():
export_size = gr.Slider(minimum=1, maximum=100, value=5, step=1)
export_size = gr.Slider(minimum=1, maximum=100, value=1, step=1)
export_quantization_bit = gr.Dropdown(choices=["none"] + GPTQ_BITS, value="none")
export_quantization_dataset = gr.Textbox(value="data/c4_demo.json")
export_device = gr.Radio(choices=["cpu", "auto"], value="cpu")
......
@@ -33,7 +33,7 @@ def create_top() -> Dict[str, "Component"]:
available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
with gr.Row():
lang = gr.Dropdown(choices=["en", "ru", "zh", "ko"], scale=1)
lang = gr.Dropdown(choices=["en", "ru", "zh"], scale=1)
model_name = gr.Dropdown(choices=available_models, scale=3)
model_path = gr.Textbox(scale=3)
......
@@ -71,7 +71,7 @@ def create_web_demo() -> "gr.Blocks":
engine = Engine(pure_chat=True)
with gr.Blocks(title="Web Demo", css=CSS) as demo:
lang = gr.Dropdown(choices=["en", "ru", "zh", "ko"], scale=1)
lang = gr.Dropdown(choices=["en", "zh"])
engine.manager.add_elems("top", dict(lang=lang))
_, _, chat_elems = create_chat_box(engine, visible=True)
......
This diff is collapsed.
@@ -104,6 +104,11 @@ class Runner:
model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type")
user_config = load_config()
if get("top.quantization_bit") in QUANTIZATION_BITS:
quantization_bit = int(get("top.quantization_bit"))
else:
quantization_bit = None
args = dict(
stage=TRAINING_STAGES[get("train.training_stage")],
do_train=True,
@@ -111,6 +116,8 @@ class Runner:
cache_dir=user_config.get("cache_dir", None),
preprocessing_num_workers=16,
finetuning_type=finetuning_type,
quantization_bit=quantization_bit,
quantization_method=get("top.quantization_method"),
template=get("top.template"),
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
@@ -159,11 +166,6 @@ class Runner:
else: # str
args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, get("top.checkpoint_path"))
# quantization
if get("top.quantization_bit") in QUANTIZATION_BITS:
args["quantization_bit"] = int(get("top.quantization_bit"))
args["quantization_method"] = get("top.quantization_method")
# freeze config
if args["finetuning_type"] == "freeze":
args["freeze_trainable_layers"] = get("train.freeze_trainable_layers")
@@ -240,12 +242,18 @@ class Runner:
model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type")
user_config = load_config()
if get("top.quantization_bit") in QUANTIZATION_BITS:
quantization_bit = int(get("top.quantization_bit"))
else:
quantization_bit = None
args = dict(
stage="sft",
model_name_or_path=get("top.model_path"),
cache_dir=user_config.get("cache_dir", None),
preprocessing_num_workers=16,
finetuning_type=finetuning_type,
quantization_bit=quantization_bit,
quantization_method=get("top.quantization_method"),
template=get("top.template"),
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
@@ -269,7 +277,6 @@ class Runner:
else:
args["do_eval"] = True
# checkpoints
if get("top.checkpoint_path"):
if finetuning_type in PEFT_METHODS: # list
args["adapter_name_or_path"] = ",".join(
@@ -278,11 +285,6 @@ class Runner:
else: # str
args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, get("top.checkpoint_path"))
# quantization
if get("top.quantization_bit") in QUANTIZATION_BITS:
args["quantization_bit"] = int(get("top.quantization_bit"))
args["quantization_method"] = get("top.quantization_method")
return args
def _preview(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict["Component", str], None, None]:
......
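The runner hunks above fold quantization handling into the argument dict: the web UI passes a string, and only values listed in `QUANTIZATION_BITS` are turned into integers, while anything else (e.g. "none") becomes `None`. A tiny sketch of that normalization with illustrative values:

QUANTIZATION_BITS = ["8", "4"]  # illustrative; the real list lives in the webui constants

def normalize_quantization_bit(raw: str):
    return int(raw) if raw in QUANTIZATION_BITS else None

assert normalize_quantization_bit("4") == 4
assert normalize_quantization_bit("none") is None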