Commit 0938ae70 authored by zhaoying1

fix save method of adapter_model.bin

parent 1b73554f
@@ -16,7 +16,7 @@ class DPODataCollatorWithPadding(DataCollatorForSeq2Seq):
             if self.tokenizer.padding_side == "left":
                 start, end = feature.size(0) - answer_len, feature.size(0)
             else:
-                start, end = prompt_len, answer_len
+                start, end = prompt_len, prompt_len + answer_len
             padded_tensor = self.label_pad_token_id * torch.ones_like(feature)
             padded_tensor[start:end] = feature[start:end]
             padded_labels.append(padded_tensor)
...
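The corrected slice masks everything except the answer span when padding on the right. A minimal sketch of the effect (made-up token ids and lengths, not code from the repository):

import torch

IGNORE_INDEX = -100  # label_pad_token_id used by llmtuner
prompt_len, answer_len = 3, 2
feature = torch.tensor([11, 12, 13, 21, 22, 0, 0])  # prompt + answer + right padding

start, end = prompt_len, prompt_len + answer_len  # fixed: end now covers the whole answer
padded = IGNORE_INDEX * torch.ones_like(feature)
padded[start:end] = feature[start:end]
print(padded)  # tensor([-100, -100, -100, 21, 22, -100, -100])

With the old end of answer_len, the supervised span would have been feature[prompt_len:answer_len], which is empty here (feature[3:2]) and misaligned in general.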
 import torch
 from collections import defaultdict
-from peft import PeftModel
 from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
 from transformers import BatchEncoding, Trainer
 from trl import DPOTrainer
+# from trl.trainer.utils import disable_dropout_in_model

 from llmtuner.extras.constants import IGNORE_INDEX
-from llmtuner.tuner.core.trainer import PeftModelMixin

 if TYPE_CHECKING:
     from transformers import PreTrainedModel
-    from llmtuner.hparams import FinetuningArguments
-class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
+class CustomDPOTrainer(DPOTrainer):

     def __init__(
         self,
-        finetuning_args: "FinetuningArguments",
+        beta: float,
+        model: Union["PreTrainedModel", torch.nn.Module],
         ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
+        disable_dropout: Optional[bool] = True,
         **kwargs
     ):
-        self.finetuning_args = finetuning_args
+        # if disable_dropout:
+        #     disable_dropout_in_model(model)
+        #     if ref_model is not None:
+        #         disable_dropout_in_model(ref_model)
+
         self.ref_model = ref_model
         self.use_dpo_data_collator = True  # hack to avoid warning
         self.label_pad_token_id = IGNORE_INDEX
         self.padding_value = 0
-        self.beta = finetuning_args.dpo_beta
+        self.beta = beta
         self._stored_metrics = defaultdict(lambda: defaultdict(list))

-        Trainer.__init__(self, **kwargs)
+        Trainer.__init__(self, model=model, **kwargs)
         if not hasattr(self, "accelerator"):
             raise AttributeError("Please update `transformers`.")

         if ref_model is not None:
-            self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
+            if self.is_deepspeed_enabled:
+                self.ref_model, = self.accelerator._prepare_deepspeed(self.ref_model)
+                self.ref_model.eval()
+            else:
+                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
     def concatenated_forward(

@@ -42,28 +50,13 @@ class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
         batch: Optional[Dict[str, torch.Tensor]] = None
     ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
         batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()})  # avoid error

-        unwrapped_model: "PreTrainedModel" = self.accelerator.unwrap_model(self.model)
-
-        if not torch.is_grad_enabled():
-            unwrapped_model.gradient_checkpointing_disable()
-
-        if model is None and isinstance(unwrapped_model, PeftModel):  # peft model has no ref_model
-            with unwrapped_model.disable_adapter():
-                all_logits = self.model(
-                    input_ids=batch_copied["input_ids"],
-                    attention_mask=batch_copied["attention_mask"],
-                    return_dict=True
-                ).logits.to(torch.float32)
-        else:
-            all_logits = model(
-                input_ids=batch_copied["input_ids"],
-                attention_mask=batch_copied["attention_mask"],
-                return_dict=True
-            ).logits.to(torch.float32)
-
-        if not torch.is_grad_enabled():
-            unwrapped_model.gradient_checkpointing_enable()
+        all_logits = model(
+            input_ids=batch_copied["input_ids"],
+            attention_mask=batch_copied["attention_mask"],
+            return_dict=True
+        ).logits.to(torch.float32)

         all_logps = self._get_batch_logps(
             all_logits,
             batch["labels"],
...
@@ -10,7 +10,7 @@ from llmtuner.extras.constants import IGNORE_INDEX
 from llmtuner.extras.ploting import plot_loss
 from llmtuner.tuner.core import load_model_and_tokenizer
 from llmtuner.tuner.dpo.collator import DPODataCollatorWithPadding
-from llmtuner.tuner.dpo.trainer import DPOPeftTrainer
+from llmtuner.tuner.dpo.trainer import CustomDPOTrainer

 if TYPE_CHECKING:
     from transformers import TrainerCallback

@@ -37,10 +37,10 @@ def run_dpo(
     training_args = Seq2SeqTrainingArguments(**training_args_dict)

     # Initialize our Trainer
-    trainer = DPOPeftTrainer(
-        finetuning_args=finetuning_args,
-        ref_model=deepcopy(model) if not isinstance(model, PeftModel) else None,
+    trainer = CustomDPOTrainer(
+        beta=finetuning_args.dpo_beta,
         model=model,
+        ref_model=deepcopy(model) if not isinstance(model, PeftModel) else None,
         args=training_args,
         tokenizer=tokenizer,
         data_collator=data_collator,
...
@@ -2,29 +2,27 @@ import os
 import math
 import torch
 from tqdm import tqdm
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple
-from transformers import TrainerState, TrainerControl
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
+from transformers import GenerationConfig, Trainer, TrainerState, TrainerControl
 from trl import PPOTrainer
 from trl.core import LengthSampler, PPODecorators, logprobs_from_logits

 from llmtuner.extras.logging import get_logger
 from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
-from llmtuner.tuner.core.trainer import PeftTrainer
 from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model

 if TYPE_CHECKING:
-    from transformers import Seq2SeqTrainingArguments
+    from transformers import Seq2SeqTrainingArguments, TrainerCallback
     from trl import AutoModelForCausalLMWithValueHead
-    from llmtuner.extras.callbacks import LogCallback
-    from llmtuner.hparams import FinetuningArguments, GeneratingArguments
+    from llmtuner.hparams import GeneratingArguments

 logger = get_logger(__name__)

-class PPOPeftTrainer(PPOTrainer, PeftTrainer):
+class CustomPPOTrainer(PPOTrainer, Trainer):
     r"""
     Inherits PPOTrainer.
     """
@@ -32,17 +30,18 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
     def __init__(
         self,
         training_args: "Seq2SeqTrainingArguments",
-        finetuning_args: "FinetuningArguments",
         generating_args: "GeneratingArguments",
-        callbacks: List["LogCallback"],
+        callbacks: List["TrainerCallback"],
         compute_dtype: torch.dtype,
         **kwargs
     ):
         PPOTrainer.__init__(self, **kwargs)
+        if getattr(self.accelerator.state, "deepspeed_plugin", None) is not None:
+            raise ValueError("PPOTrainer is incompatible with DeepSpeed.")
+
         self.args = training_args
-        self.finetuning_args = finetuning_args
         self.generating_args = generating_args
-        self.log_callback = callbacks[0]
+        self.log_callback, self.save_callback = callbacks[0], callbacks[1]
         self.compute_dtype = compute_dtype
         self.state = TrainerState()
         self.control = TrainerControl()
@@ -75,10 +74,11 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         logger.info(f"  Number of trainable parameters = {count_parameters(self.model)[0]}")

         # Keyword arguments for `model.generate`
-        gen_kwargs = self.generating_args.to_dict()
-        gen_kwargs["eos_token_id"] = list(set([self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids))
-        gen_kwargs["pad_token_id"] = self.tokenizer.pad_token_id
-        gen_kwargs["logits_processor"] = get_logits_processor()
+        generating_args = self.generating_args.to_dict()
+        generating_args.update(dict(
+            eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
+            pad_token_id=self.tokenizer.pad_token_id
+        ))

         length_sampler = LengthSampler(max_target_length // 2, max_target_length)
         unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
@@ -96,17 +96,23 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
             # Cast to inference mode
             unwrapped_model.gradient_checkpointing_disable()
             unwrapped_model.config.use_cache = True
+            unwrapped_model, layer_norm_params = cast_layernorm_dtype(unwrapped_model, self.compute_dtype)
+            self.model.eval()

             # Get inputs
-            queries, responses = self.get_inputs(batch, length_sampler, **gen_kwargs)
+            queries, responses = self.get_inputs(batch, length_sampler, generating_args)
+            self.tokenizer.padding_side = "right"  # change padding side
             rewards = self.get_rewards(queries, responses, unwrapped_model)

             # Cast to training mode
             unwrapped_model.gradient_checkpointing_enable()
             unwrapped_model.config.use_cache = False
+            unwrapped_model, _ = cast_layernorm_dtype(unwrapped_model, self.compute_dtype, layer_norm_params)
+            self.model.train()

             # Run PPO step
             stats = self.step(queries, responses, rewards)
+            self.tokenizer.padding_side = "left"  # restore padding side
             loss_meter.update(stats["ppo/loss/total"], n=len(rewards))
             reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))
@@ -137,36 +143,44 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
                 dataiter = iter(self.dataloader)
                 steps_trained = 0

-        self.log_callback.on_train_end(self.args, self.state, self.control)
+        self.log_callback.on_train_end(
+            self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
+        )

     @torch.no_grad()
     def get_inputs(
         self,
         batch: Dict[str, torch.Tensor],
-        length_sampler: Optional[Callable] = None,
-        **generation_kwargs
+        length_sampler: Callable,
+        generating_args: Dict[str, Any]
     ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
         r"""
         Generates model's responses given queries.
         """
-        if length_sampler is not None:
-            generation_kwargs["max_new_tokens"] = length_sampler()
+        generating_args["max_new_tokens"] = length_sampler()
+        gen_kwargs = dict(
+            generation_config=GenerationConfig(**generating_args),
+            logits_processor=get_logits_processor(),
+            **batch
+        )

-        self.model, layer_norm_params = cast_layernorm_dtype(self.model, self.compute_dtype)
-
+        input_ids = batch["input_ids"]
         unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
-        response: torch.Tensor = unwrapped_model.generate(**batch, **generation_kwargs)
-
-        self.model, _ = cast_layernorm_dtype(self.model, self.compute_dtype, layer_norm_params)
-
-        # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
-        # Inspired by: https://github.com/huggingface/transformers/blob/v4.28.1/src/transformers/trainer_seq2seq.py#L273
-        if unwrapped_model.pretrained_model.generation_config._from_model_config:
-            unwrapped_model.pretrained_model.generation_config._from_model_config = False
+        response: torch.Tensor = unwrapped_model.generate(**gen_kwargs)
+        query, response = input_ids.detach().cpu(), response[:, input_ids.size(-1):].detach().cpu()

         queries, responses = [], []
-        query, response = batch["input_ids"].detach().cpu(), response[:, batch["input_ids"].size(-1):].detach().cpu()
         for i in range(len(query)):
             query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0]
-            response_length = (response[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
+            response_index = (response[i] != self.tokenizer.pad_token_id).nonzero()
+
+            if len(response_index) == 0:
+                response_length = 1  # allow empty response
+            elif self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
+                response_length = response_index[-1] + 2  # save the EOS token
+            else:
+                response_length = response_index[-1] + 1
+
             queries.append(query[i, query_length:])  # remove padding from left
             responses.append(response[i, :response_length])  # remove padding from right
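Note on the new response-length logic: when the tokenizer reuses EOS as the padding token, trimming every pad id would also drop the EOS that terminates the response, so the branch above keeps one extra token. A self-contained sketch with made-up ids (not repository code):

import torch

pad_token_id = eos_token_id = 2                     # many chat models reuse EOS as PAD
response = torch.tensor([5, 6, 7, 2, 2, 2])         # generated text, EOS, then padding

response_index = (response != pad_token_id).nonzero()
if len(response_index) == 0:
    response_length = 1                             # allow an empty response
elif pad_token_id == eos_token_id:
    response_length = response_index[-1].item() + 2  # keep the trailing EOS token
else:
    response_length = response_index[-1].item() + 1

print(response[:response_length])                   # tensor([5, 6, 7, 2])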
@@ -191,7 +205,11 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         if values.size(0) != batch["input_ids"].size(0):  # adapt to chatglm2
             values = torch.transpose(values, 0, 1)

-        rewards = [reward for reward in values[:, -1].float().detach().cpu()]  # use fp32 type
+        rewards = []
+        for i in range(values.size(0)):
+            end_index = batch["attention_mask"][i].nonzero()[-1]  # use the score on the EOS token
+            rewards.append(values[i, end_index].float().detach().cpu())  # use fp32 type
+
         replace_model(unwrapped_model, target="default")
         return rewards
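The reward is now read at the last attended position of each right-padded sequence instead of always at index -1, which would land on a PAD token for shorter sequences. An illustrative sketch with made-up numbers (not repository code):

import torch

values = torch.tensor([[0.1, 0.2, 0.3, 0.0],
                       [0.5, 0.6, 0.0, 0.0]])        # (batch, seq_len) value-head scores
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])

rewards = []
for i in range(values.size(0)):
    end_index = attention_mask[i].nonzero()[-1]      # index of the last real token
    rewards.append(values[i, end_index].float())

print(rewards)  # [tensor([0.3000]), tensor([0.6000])]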
@@ -202,7 +220,8 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         queries: torch.Tensor,
         responses: torch.Tensor,
         model_inputs: dict,
-        return_logits: Optional[bool] = False
+        return_logits: Optional[bool] = False,
+        response_masks: Optional[torch.Tensor] = None
     ):
         r"""
         Calculates model outputs in multiple batches.

@@ -220,6 +239,8 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
             input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()}
             query_batch = queries[i * fbs : (i + 1) * fbs]
             response_batch = responses[i * fbs : (i + 1) * fbs]
+            if response_masks is not None:
+                response_masks_batch = response_masks[i * fbs : (i + 1) * fbs]
             input_ids = input_kwargs["input_ids"]
             attention_mask = input_kwargs["attention_mask"]

@@ -239,8 +260,15 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
                     start += attention_mask[j, :].nonzero()[0]
                 end = start + len(response_batch[j])
+                if response_masks is not None:
+                    response_masks_batch = torch.cat(
+                        (torch.zeros_like(query_batch[j]), response_masks_batch[j])
+                    )[1:]

                 masks[j, :start] = 0
                 masks[j, end:] = 0
+                if response_masks is not None:
+                    masks[j, start:end] = masks[j, start:end] * response_masks_batch[j][start:end]

             if return_logits:
                 all_logits.append(logits)
@@ -266,3 +294,6 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         """
         if self.args.should_save:
             self._save(output_dir)
+            self.save_callback.on_save(
+                self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
+            )
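This hunk is the fix named in the commit title: writing adapter_model.bin is delegated to the save callback, which now receives the unwrapped model. The callback itself lives in llmtuner/extras/callbacks.py (imported below) and is not shown in this diff; the following is only a rough sketch of what such a callback might do, assuming the unwrapped value-head model exposes its PEFT backbone as pretrained_model:

# Hypothetical sketch -- not the repository's actual SavePeftModelCallback.
import os
from transformers import TrainerCallback

class SavePeftModelCallbackSketch(TrainerCallback):
    def on_save(self, args, state, control, **kwargs):
        model = kwargs.get("model")                              # unwrapped model passed by the trainer
        peft_model = getattr(model, "pretrained_model", model)   # assumption: value head wraps a PeftModel
        if peft_model is not None and hasattr(peft_model, "save_pretrained"):
            output_dir = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
            peft_model.save_pretrained(output_dir)               # writes adapter_config.json + adapter weights
        return control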
@@ -4,14 +4,14 @@ import math
 from trl import PPOConfig
 from torch.optim import AdamW
 from typing import TYPE_CHECKING, Optional, List
-from transformers import DataCollatorForSeq2Seq
+from transformers import DataCollatorWithPadding
 from transformers.optimization import get_scheduler
-from transformers.utils.versions import require_version

 from llmtuner.dsets import get_dataset, preprocess_dataset
+from llmtuner.extras.callbacks import SavePeftModelCallback
 from llmtuner.extras.ploting import plot_loss
 from llmtuner.tuner.core import load_model_and_tokenizer
-from llmtuner.tuner.ppo.trainer import PPOPeftTrainer
+from llmtuner.tuner.ppo.trainer import CustomPPOTrainer

 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, TrainerCallback

@@ -29,7 +29,9 @@ def run_ppo(
     dataset = get_dataset(model_args, data_args)
     model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="ppo")
     dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="ppo")
-    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=tokenizer.pad_token_id)
+
+    tokenizer.padding_side = "left"  # use left-padding in generation while using right-padding in training
+    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

     ppo_config = PPOConfig(
         model_name=model_args.model_name_or_path,

@@ -44,7 +46,6 @@ def run_ppo(
     )

     if finetuning_args.ppo_score_norm:
-        require_version("trl>=0.5.1.dev0", "To fix: pip install git+https://github.com/huggingface/trl.git")
         ppo_config.use_score_scaling = True
         ppo_config.use_score_norm = True

@@ -61,11 +62,10 @@ def run_ppo(
     )

     # Initialize our Trainer
-    ppo_trainer = PPOPeftTrainer(
+    ppo_trainer = CustomPPOTrainer(
         training_args=training_args,
-        finetuning_args=finetuning_args,
         generating_args=generating_args,
-        callbacks=callbacks,
+        callbacks=callbacks + [SavePeftModelCallback()],
         compute_dtype=model_args.compute_dtype,
         config=ppo_config,
         model=model,
...
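The workflow now switches the tokenizer to left padding before building the collator, because decoder-only generation should continue from real prompt tokens rather than trailing PAD ids (the trainer temporarily flips to right padding around the reward and PPO step, as shown in the hunks above). A small illustration of the collator behaviour, assuming any causal-LM tokenizer (gpt2 is used here purely as an example):

from transformers import AutoTokenizer, DataCollatorWithPadding

tokenizer = AutoTokenizer.from_pretrained("gpt2")   # example tokenizer only
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
features = [tokenizer("Hello"), tokenizer("Hello there, how are you?")]
batch = data_collator(features)
print(batch["input_ids"])        # the shorter prompt is padded on the left
print(batch["attention_mask"])   # zeros appear at the start of the shorter row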
@@ -2,12 +2,11 @@
 import math
 from typing import TYPE_CHECKING, Optional, List
-from transformers import DataCollatorForLanguageModeling
+from transformers import DataCollatorForLanguageModeling, Trainer

 from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
 from llmtuner.extras.ploting import plot_loss
 from llmtuner.tuner.core import load_model_and_tokenizer
-from llmtuner.tuner.core.trainer import PeftTrainer

 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, TrainerCallback

@@ -27,8 +26,7 @@ def run_pt(
     data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

     # Initialize our Trainer
-    trainer = PeftTrainer(
-        finetuning_args=finetuning_args,
+    trainer = Trainer(
         model=model,
         args=training_args,
         tokenizer=tokenizer,
...