Unverified Commit 079bf3cb authored by Hongxin Liu's avatar Hongxin Liu Committed by GitHub
Browse files

[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
parent 3c6b831c
...@@ -39,8 +39,9 @@ class SFTTrainer(SLTrainer): ...@@ -39,8 +39,9 @@ class SFTTrainer(SLTrainer):
accumulation_steps: int = 8, accumulation_steps: int = 8,
) -> None: ) -> None:
if accumulation_steps > 1: if accumulation_steps > 1:
assert not isinstance(strategy, GeminiStrategy), \ assert not isinstance(
"Accumulation steps are not supported in stage 3 of ColossalAI" strategy, GeminiStrategy
), "Accumulation steps are not supported in stage 3 of ColossalAI"
super().__init__(strategy, max_epochs, model, optim) super().__init__(strategy, max_epochs, model, optim)
...@@ -50,15 +51,11 @@ class SFTTrainer(SLTrainer): ...@@ -50,15 +51,11 @@ class SFTTrainer(SLTrainer):
def _train(self, epoch: int): def _train(self, epoch: int):
self.model.train() self.model.train()
for batch_id, batch in enumerate(self.train_dataloader): for batch_id, batch in enumerate(self.train_dataloader):
batch = to_device(batch, torch.cuda.current_device()) batch = to_device(batch, torch.cuda.current_device())
if "attention_mask" in batch: if "attention_mask" in batch:
outputs = self.model(batch["input_ids"], outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
attention_mask=batch["attention_mask"],
labels=batch["labels"])
else: else:
outputs = self.model(batch["input_ids"], outputs = self.model(batch["input_ids"], labels=batch["labels"])
labels=batch["labels"])
loss = outputs.loss loss = outputs.loss
loss = loss / self.accumulation_steps loss = loss / self.accumulation_steps
...@@ -73,12 +70,14 @@ class SFTTrainer(SLTrainer): ...@@ -73,12 +70,14 @@ class SFTTrainer(SLTrainer):
self.optimizer.zero_grad() self.optimizer.zero_grad()
self.scheduler.step() self.scheduler.step()
if is_rank_0() and self.use_wandb: if is_rank_0() and self.use_wandb:
wandb.log({ wandb.log(
"loss": self.total_loss / self.accumulation_steps, {
"lr": self.scheduler.get_last_lr()[0], "loss": self.total_loss / self.accumulation_steps,
"epoch": epoch, "lr": self.scheduler.get_last_lr()[0],
"batch_id": batch_id "epoch": epoch,
}) "batch_id": batch_id,
}
)
self.total_loss = 0 self.total_loss = 0
self.step_bar.update() self.step_bar.update()
...@@ -89,9 +88,9 @@ class SFTTrainer(SLTrainer): ...@@ -89,9 +88,9 @@ class SFTTrainer(SLTrainer):
loss_sum, num_seen = 0, 0 loss_sum, num_seen = 0, 0
for batch in self.eval_dataloader: for batch in self.eval_dataloader:
batch = to_device(batch, torch.cuda.current_device()) batch = to_device(batch, torch.cuda.current_device())
outputs = self.model(batch["input_ids"], outputs = self.model(
attention_mask=batch["attention_mask"], batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"]
labels=batch["labels"]) )
loss = outputs.loss loss = outputs.loss
loss_sum += loss.item() loss_sum += loss.item()
...@@ -99,13 +98,15 @@ class SFTTrainer(SLTrainer): ...@@ -99,13 +98,15 @@ class SFTTrainer(SLTrainer):
loss_mean = loss_sum / num_seen loss_mean = loss_sum / num_seen
if dist.get_rank() == 0: if dist.get_rank() == 0:
self.logger.info(f'Eval Epoch {epoch}/{self.max_epochs} loss {loss_mean}') self.logger.info(f"Eval Epoch {epoch}/{self.max_epochs} loss {loss_mean}")
def _before_fit(self, def _before_fit(
train_dataloader: DataLoader, self,
eval_dataloader: Optional[DataLoader] = None, train_dataloader: DataLoader,
logger: Optional[DistributedLogger] = None, eval_dataloader: Optional[DataLoader] = None,
use_wandb: bool = False): logger: Optional[DistributedLogger] = None,
use_wandb: bool = False,
):
""" """
Args: Args:
train_dataloader: the dataloader to use for training train_dataloader: the dataloader to use for training
...@@ -124,6 +125,6 @@ class SFTTrainer(SLTrainer): ...@@ -124,6 +125,6 @@ class SFTTrainer(SLTrainer):
self.no_epoch_bar = True self.no_epoch_bar = True
self.step_bar = tqdm.trange( self.step_bar = tqdm.trange(
len(self.train_dataloader) // self.accumulation_steps * self.max_epochs, len(self.train_dataloader) // self.accumulation_steps * self.max_epochs,
desc=f'steps', desc=f"steps",
disable=not is_rank_0() disable=not is_rank_0(),
) )
...@@ -2,7 +2,4 @@ from .base import Strategy ...@@ -2,7 +2,4 @@ from .base import Strategy
from .colossalai import GeminiStrategy, LowLevelZeroStrategy from .colossalai import GeminiStrategy, LowLevelZeroStrategy
from .ddp import DDPStrategy from .ddp import DDPStrategy
__all__ = [ __all__ = ["Strategy", "DDPStrategy", "LowLevelZeroStrategy", "GeminiStrategy"]
'Strategy', 'DDPStrategy',
'LowLevelZeroStrategy', 'GeminiStrategy'
]
...@@ -19,7 +19,7 @@ _BoostArgSpec = Union[nn.Module, Tuple[nn.Module, Optimizer], Dict] ...@@ -19,7 +19,7 @@ _BoostArgSpec = Union[nn.Module, Tuple[nn.Module, Optimizer], Dict]
class Strategy(ABC): class Strategy(ABC):
""" """
Base class for training strategies. Base class for training strategies.
""" """
def __init__(self, plugin_initializer: Callable[..., Optional[Plugin]] = lambda: None) -> None: def __init__(self, plugin_initializer: Callable[..., Optional[Plugin]] = lambda: None) -> None:
...@@ -83,16 +83,18 @@ class Strategy(ABC): ...@@ -83,16 +83,18 @@ class Strategy(ABC):
rets.append((model, optimizer)) rets.append((model, optimizer))
elif isinstance(arg, Dict): elif isinstance(arg, Dict):
model, optimizer, criterion, dataloader, lr_scheduler = self.booster.boost(**arg) model, optimizer, criterion, dataloader, lr_scheduler = self.booster.boost(**arg)
boost_result = dict(model=model, boost_result = dict(
optimizer=optimizer, model=model,
criterion=criterion, optimizer=optimizer,
dataloader=dataloader, criterion=criterion,
lr_scheduler=lr_scheduler) dataloader=dataloader,
lr_scheduler=lr_scheduler,
)
# remove None values # remove None values
boost_result = {key: value for key, value in boost_result.items() if value is not None} boost_result = {key: value for key, value in boost_result.items() if value is not None}
rets.append(boost_result) rets.append(boost_result)
else: else:
raise RuntimeError(f'Type {type(arg)} is not supported') raise RuntimeError(f"Type {type(arg)} is not supported")
return rets[0] if len(rets) == 1 else rets return rets[0] if len(rets) == 1 else rets
...@@ -125,11 +127,9 @@ class Strategy(ABC): ...@@ -125,11 +127,9 @@ class Strategy(ABC):
return DistributedSampler(dataset, 1, 0) return DistributedSampler(dataset, 1, 0)
@abstractmethod @abstractmethod
def save_pretrained(self, def save_pretrained(
model: nn.Module, self, model: nn.Module, path: str, only_rank0: bool = True, tokenizer: Optional[PreTrainedTokenizerBase] = None
path: str, ) -> None:
only_rank0: bool = True,
tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
pass pass
@abstractmethod @abstractmethod
......
...@@ -42,27 +42,27 @@ class LowLevelZeroStrategy(DDPStrategy): ...@@ -42,27 +42,27 @@ class LowLevelZeroStrategy(DDPStrategy):
""" """
def __init__(self, def __init__(
stage: int = 2, self,
precision: str = 'fp16', stage: int = 2,
seed: int = 42, precision: str = "fp16",
placement_policy: str = 'cuda', seed: int = 42,
reduce_bucket_size: int = 12 * 1024**2, # only for stage 1&2 placement_policy: str = "cuda",
overlap_communication: bool = True, # only for stage 1&2 reduce_bucket_size: int = 12 * 1024**2, # only for stage 1&2
initial_scale: float = 2**16, overlap_communication: bool = True, # only for stage 1&2
growth_factor: float = 2, initial_scale: float = 2**16,
backoff_factor: float = 0.5, growth_factor: float = 2,
growth_interval: int = 1000, backoff_factor: float = 0.5,
hysteresis: int = 2, growth_interval: int = 1000,
min_scale: float = 1, hysteresis: int = 2,
max_scale: float = 2**32, min_scale: float = 1,
max_norm: float = 0.0, max_scale: float = 2**32,
norm_type: float = 2.0 max_norm: float = 0.0,
) -> None: norm_type: float = 2.0,
) -> None:
assert stage in (1, 2), f'Unsupported stage "{stage}"' assert stage in (1, 2), f'Unsupported stage "{stage}"'
assert placement_policy in ('cpu', 'cuda'), f'Unsupported placement policy "{placement_policy}"' assert placement_policy in ("cpu", "cuda"), f'Unsupported placement policy "{placement_policy}"'
assert precision in ('fp32', 'fp16'), f'Unsupported precision "{precision}"' assert precision in ("fp32", "fp16"), f'Unsupported precision "{precision}"'
plugin_initializer = lambda: LowLevelZeroPlugin( plugin_initializer = lambda: LowLevelZeroPlugin(
# zero_config # zero_config
...@@ -71,7 +71,7 @@ class LowLevelZeroStrategy(DDPStrategy): ...@@ -71,7 +71,7 @@ class LowLevelZeroStrategy(DDPStrategy):
# zero_optim_config # zero_optim_config
reduce_bucket_size_in_m=reduce_bucket_size, reduce_bucket_size_in_m=reduce_bucket_size,
overlap_communication=overlap_communication, overlap_communication=overlap_communication,
cpu_offload=(placement_policy == 'cpu'), cpu_offload=(placement_policy == "cpu"),
# optim_config # optim_config
initial_scale=initial_scale, initial_scale=initial_scale,
growth_factor=growth_factor, growth_factor=growth_factor,
...@@ -81,14 +81,15 @@ class LowLevelZeroStrategy(DDPStrategy): ...@@ -81,14 +81,15 @@ class LowLevelZeroStrategy(DDPStrategy):
min_scale=min_scale, min_scale=min_scale,
max_scale=max_scale, max_scale=max_scale,
max_norm=max_norm, max_norm=max_norm,
norm_type=norm_type norm_type=norm_type,
) )
super().__init__(seed, plugin_initializer) super().__init__(seed, plugin_initializer)
def _post_init(self) -> None: def _post_init(self) -> None:
assert isinstance(self.plugin, LowLevelZeroPlugin), \ assert isinstance(
f'{type(self).__name__}\'s plugin is not initialized properly.' self.plugin, LowLevelZeroPlugin
), f"{type(self).__name__}'s plugin is not initialized properly."
def setup_distributed(self) -> None: def setup_distributed(self) -> None:
colossalai.launch_from_torch({}, seed=self.seed) colossalai.launch_from_torch({}, seed=self.seed)
...@@ -131,45 +132,45 @@ class GeminiStrategy(DDPStrategy): ...@@ -131,45 +132,45 @@ class GeminiStrategy(DDPStrategy):
""" """
def __init__(self, def __init__(
seed: int = 42, self,
shard_init: bool = False, # only for stage 3 seed: int = 42,
placement_policy: str = 'cuda', shard_init: bool = False, # only for stage 3
pin_memory: bool = True, # only for stage 3 placement_policy: str = "cuda",
force_outputs_fp32: bool = False, # only for stage 3 pin_memory: bool = True, # only for stage 3
search_range_m: int = 32, # only for stage 3 force_outputs_fp32: bool = False, # only for stage 3
hidden_dim: Optional[int] = None, # only for stage 3 search_range_m: int = 32, # only for stage 3
min_chunk_size_m: float = 32, # only for stage 3 hidden_dim: Optional[int] = None, # only for stage 3
gpu_margin_mem_ratio: float = 0.0, # only for stage 3 min_chunk_size_m: float = 32, # only for stage 3
initial_scale: float = 2**16, gpu_margin_mem_ratio: float = 0.0, # only for stage 3
growth_factor: float = 2, initial_scale: float = 2**16,
backoff_factor: float = 0.5, growth_factor: float = 2,
growth_interval: int = 1000, backoff_factor: float = 0.5,
hysteresis: int = 2, growth_interval: int = 1000,
min_scale: float = 1, hysteresis: int = 2,
max_scale: float = 2**32, min_scale: float = 1,
max_norm: float = 0.0, max_scale: float = 2**32,
norm_type: float = 2.0 max_norm: float = 0.0,
) -> None: norm_type: float = 2.0,
) -> None:
assert placement_policy in ('cpu', 'cuda'), f'Unsupported placement policy "{placement_policy}"' assert placement_policy in ("cpu", "cuda"), f'Unsupported placement policy "{placement_policy}"'
# TODO(ver217): support shard_init when using from_pretrained() # TODO(ver217): support shard_init when using from_pretrained()
if shard_init: if shard_init:
warnings.warn( warnings.warn(
f'Shard init is not supported model.from_pretrained() yet. ' f"Shard init is not supported model.from_pretrained() yet. "
'Please load weights after strategy.prepare()' "Please load weights after strategy.prepare()"
) )
self.shard_init = shard_init self.shard_init = shard_init
warnings.warn(f'Stage 3 only supports fp16. Precision is set to fp16.') warnings.warn(f"Stage 3 only supports fp16. Precision is set to fp16.")
# NOTE: dist should be initialized before calling get_current_device() # NOTE: dist should be initialized before calling get_current_device()
plugin_initializer = lambda: GeminiPlugin( plugin_initializer = lambda: GeminiPlugin(
# gemini_config # gemini_config
device=get_current_device(), device=get_current_device(),
placement_policy=placement_policy, placement_policy=placement_policy,
precision='fp16', precision="fp16",
pin_memory=pin_memory, pin_memory=pin_memory,
force_outputs_fp32=force_outputs_fp32, force_outputs_fp32=force_outputs_fp32,
strict_ddp_mode=shard_init, strict_ddp_mode=shard_init,
...@@ -187,14 +188,13 @@ class GeminiStrategy(DDPStrategy): ...@@ -187,14 +188,13 @@ class GeminiStrategy(DDPStrategy):
min_scale=min_scale, min_scale=min_scale,
max_scale=max_scale, max_scale=max_scale,
max_norm=max_norm, max_norm=max_norm,
norm_type=norm_type norm_type=norm_type,
) )
super().__init__(seed, plugin_initializer) super().__init__(seed, plugin_initializer)
def _post_init(self) -> None: def _post_init(self) -> None:
assert isinstance(self.plugin, GeminiPlugin), \ assert isinstance(self.plugin, GeminiPlugin), f"{type(self).__name__}'s plugin is not initialized properly."
f'{type(self).__name__}\'s plugin is not initialized properly.'
def setup_distributed(self) -> None: def setup_distributed(self) -> None:
colossalai.launch_from_torch({}, seed=self.seed) colossalai.launch_from_torch({}, seed=self.seed)
...@@ -203,10 +203,9 @@ class GeminiStrategy(DDPStrategy): ...@@ -203,10 +203,9 @@ class GeminiStrategy(DDPStrategy):
world_size = dist.get_world_size() world_size = dist.get_world_size()
shard_pg = ProcessGroup(tp_degree=world_size) if self.shard_init else None shard_pg = ProcessGroup(tp_degree=world_size) if self.shard_init else None
default_dist_spec = ShardSpec([-1], [world_size]) if self.shard_init else None default_dist_spec = ShardSpec([-1], [world_size]) if self.shard_init else None
return ColoInitContext(device=get_current_device(), return ColoInitContext(
dtype=torch.half, device=get_current_device(), dtype=torch.half, default_pg=shard_pg, default_dist_spec=default_dist_spec
default_pg=shard_pg, )
default_dist_spec=default_dist_spec)
def unwrap_model(self, model: nn.Module) -> nn.Module: def unwrap_model(self, model: nn.Module) -> nn.Module:
assert isinstance(model, GeminiModel) assert isinstance(model, GeminiModel)
......
...@@ -31,24 +31,21 @@ def get_grad_required_state_dict(model: nn.Module): ...@@ -31,24 +31,21 @@ def get_grad_required_state_dict(model: nn.Module):
class DDPStrategy(Strategy): class DDPStrategy(Strategy):
""" """
Strategy for distributed training using torch.distributed. Strategy for distributed training using torch.distributed.
""" """
def __init__(self, def __init__(self, seed: int = 42, plugin_initializer: Callable = TorchDDPPlugin) -> None:
seed: int = 42,
plugin_initializer: Callable = TorchDDPPlugin
) -> None:
self.seed = seed self.seed = seed
super().__init__(plugin_initializer) super().__init__(plugin_initializer)
def _try_init_dist(self, force: bool = False) -> None: def _try_init_dist(self, force: bool = False) -> None:
try: try:
rank = int(os.environ['RANK']) rank = int(os.environ["RANK"])
local_rank = int(os.environ['LOCAL_RANK']) local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ['WORLD_SIZE']) world_size = int(os.environ["WORLD_SIZE"])
host = os.environ['MASTER_ADDR'] host = os.environ["MASTER_ADDR"]
port = int(os.environ['MASTER_PORT']) port = int(os.environ["MASTER_PORT"])
dist.init_process_group('nccl', init_method=f'tcp://[{host}]:{port}', world_size=world_size, rank=rank) dist.init_process_group("nccl", init_method=f"tcp://[{host}]:{port}", world_size=world_size, rank=rank)
torch.cuda.set_device(local_rank) torch.cuda.set_device(local_rank)
except KeyError as e: except KeyError as e:
if force: if force:
...@@ -60,8 +57,7 @@ class DDPStrategy(Strategy): ...@@ -60,8 +57,7 @@ class DDPStrategy(Strategy):
raise e raise e
def _post_init(self) -> None: def _post_init(self) -> None:
assert isinstance(self.plugin, TorchDDPPlugin), \ assert isinstance(self.plugin, TorchDDPPlugin), f"{type(self).__name__}'s plugin is not initialized properly."
f'{type(self).__name__}\'s plugin is not initialized properly.'
def setup_distributed(self) -> None: def setup_distributed(self) -> None:
self._try_init_dist(force=True) self._try_init_dist(force=True)
...@@ -73,12 +69,14 @@ class DDPStrategy(Strategy): ...@@ -73,12 +69,14 @@ class DDPStrategy(Strategy):
torch.manual_seed(seed) torch.manual_seed(seed)
def setup_dataloader(self, data_buffer: ExperienceBuffer, pin_memory: bool = False) -> DataLoader: def setup_dataloader(self, data_buffer: ExperienceBuffer, pin_memory: bool = False) -> DataLoader:
return self.plugin.prepare_dataloader(data_buffer, return self.plugin.prepare_dataloader(
batch_size=data_buffer.sample_batch_size, data_buffer,
shuffle=True, batch_size=data_buffer.sample_batch_size,
drop_last=True, shuffle=True,
pin_memory=pin_memory, drop_last=True,
collate_fn=data_buffer.collate_fn) pin_memory=pin_memory,
collate_fn=data_buffer.collate_fn,
)
def setup_sampler(self, dataset) -> DistributedSampler: def setup_sampler(self, dataset) -> DistributedSampler:
# FIXME(cwher): this is only invoked in train_on_ray, not tested after adapt Boost API. # FIXME(cwher): this is only invoked in train_on_ray, not tested after adapt Boost API.
...@@ -88,11 +86,9 @@ class DDPStrategy(Strategy): ...@@ -88,11 +86,9 @@ class DDPStrategy(Strategy):
assert isinstance(model, TorchDDPModel), "model is not wrapped by TorchDDPModel." assert isinstance(model, TorchDDPModel), "model is not wrapped by TorchDDPModel."
return model.unwrap() return model.unwrap()
def save_pretrained(self, def save_pretrained(
model: nn.Module, self, model: nn.Module, path: str, only_rank0: bool = True, tokenizer: Optional[PreTrainedTokenizerBase] = None
path: str, ) -> None:
only_rank0: bool = True,
tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
if not only_rank0 or dist.get_rank() == 0: if not only_rank0 or dist.get_rank() == 0:
unwrapped_model = self.unwrap_model(model) unwrapped_model = self.unwrap_model(model)
assert isinstance(unwrapped_model, (Actor, Critic, RewardModel)) assert isinstance(unwrapped_model, (Actor, Critic, RewardModel))
...@@ -103,17 +99,11 @@ class DDPStrategy(Strategy): ...@@ -103,17 +99,11 @@ class DDPStrategy(Strategy):
if tokenizer is not None: if tokenizer is not None:
tokenizer.save_pretrained(path) tokenizer.save_pretrained(path)
model_path = os.path.join(path, "pytorch_model.bin") model_path = os.path.join(path, "pytorch_model.bin")
self.save_model(model, self.save_model(model, model_path, only_rank0=only_rank0)
model_path,
only_rank0=only_rank0)
def _replace_keys(model_path: str, def _replace_keys(model_path: str, replace_fn: Callable):
replace_fn: Callable):
state_dict = torch.load(model_path, map_location="cpu") state_dict = torch.load(model_path, map_location="cpu")
state_dict = { state_dict = {replace_fn(k): v for k, v in state_dict.items()}
replace_fn(k): v
for k, v in state_dict.items()
}
torch.save(state_dict, model_path) torch.save(state_dict, model_path)
# FIXME: save_model would add "model." prefix to keys of pytorch_model.bin # FIXME: save_model would add "model." prefix to keys of pytorch_model.bin
...@@ -124,13 +114,13 @@ class DDPStrategy(Strategy): ...@@ -124,13 +114,13 @@ class DDPStrategy(Strategy):
def get_model_state_dict_shard(self, model: nn.Module, **config): def get_model_state_dict_shard(self, model: nn.Module, **config):
# TODO: implement sharding on naive strategy # TODO: implement sharding on naive strategy
model = self.unwrap_model(model) model = self.unwrap_model(model)
if 'requires_grad_only' in config and config['requires_grad_only'] == True: if "requires_grad_only" in config and config["requires_grad_only"] == True:
state_dict = get_grad_required_state_dict(model) state_dict = get_grad_required_state_dict(model)
else: else:
state_dict = model.state_dict() state_dict = model.state_dict()
if 'shard_size' in config: if "shard_size" in config:
shard_size = config['shard_size'] shard_size = config["shard_size"]
accumulate_size = 0 accumulate_size = 0
state_dict_shard = OrderedDict() state_dict_shard = OrderedDict()
for name, param in state_dict.items(): for name, param in state_dict.items():
......
...@@ -4,7 +4,6 @@ import numpy as np ...@@ -4,7 +4,6 @@ import numpy as np
class DistributedSampler: class DistributedSampler:
def __init__(self, dataset, num_replicas: int, rank: int) -> None: def __init__(self, dataset, num_replicas: int, rank: int) -> None:
self.dataset = dataset self.dataset = dataset
self.num_replicas = num_replicas self.num_replicas = num_replicas
...@@ -12,7 +11,7 @@ class DistributedSampler: ...@@ -12,7 +11,7 @@ class DistributedSampler:
if len(self.dataset) % self.num_replicas != 0: if len(self.dataset) % self.num_replicas != 0:
self.num_samples = math.ceil( self.num_samples = math.ceil(
(len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type] (len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type]
) )
else: else:
self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)
...@@ -20,10 +19,10 @@ class DistributedSampler: ...@@ -20,10 +19,10 @@ class DistributedSampler:
self.total_size = self.num_samples * self.num_replicas self.total_size = self.num_samples * self.num_replicas
indices = list(range(len(self.dataset))) indices = list(range(len(self.dataset)))
indices = indices[:self.total_size] indices = indices[: self.total_size]
assert len(indices) == self.total_size assert len(indices) == self.total_size
# subsample # subsample
indices = indices[self.rank:self.total_size:self.num_replicas] indices = indices[self.rank : self.total_size : self.num_replicas]
assert len(indices) == self.num_samples assert len(indices) == self.num_samples
self.indices = indices self.indices = indices
......
...@@ -42,7 +42,6 @@ def is_rank_0() -> bool: ...@@ -42,7 +42,6 @@ def is_rank_0() -> bool:
def to_device(x: Any, device: torch.device) -> Any: def to_device(x: Any, device: torch.device) -> Any:
def _to(t: Any): def _to(t: Any):
if isinstance(t, torch.Tensor): if isinstance(t, torch.Tensor):
return t.to(device) return t.to(device)
......
...@@ -70,7 +70,7 @@ ...@@ -70,7 +70,7 @@
"BLEU", "BLEU",
"ROUGE", "ROUGE",
"BERTScore" "BERTScore"
] ]
}, },
"logical_reasoning": { "logical_reasoning": {
"GPT": [ "GPT": [
...@@ -83,7 +83,7 @@ ...@@ -83,7 +83,7 @@
"ROUGE", "ROUGE",
"BERTScore", "BERTScore",
"CHRF" "CHRF"
] ]
}, },
"open_qa": { "open_qa": {
"GPT": [ "GPT": [
...@@ -126,7 +126,7 @@ ...@@ -126,7 +126,7 @@
"conciseness" "conciseness"
], ],
"Metrics": [ "Metrics": [
] ]
}, },
"Finance": { "Finance": {
"GPT": [ "GPT": [
...@@ -134,7 +134,7 @@ ...@@ -134,7 +134,7 @@
"correctness" "correctness"
], ],
"Metrics": [ "Metrics": [
] ]
}, },
"Law": { "Law": {
"GPT": [ "GPT": [
...@@ -142,7 +142,7 @@ ...@@ -142,7 +142,7 @@
"correctness" "correctness"
], ],
"Metrics": [ "Metrics": [
] ]
}, },
"Education": { "Education": {
"GPT": [ "GPT": [
...@@ -150,7 +150,7 @@ ...@@ -150,7 +150,7 @@
"correctness" "correctness"
], ],
"Metrics": [ "Metrics": [
] ]
}, },
"Medical": { "Medical": {
"GPT": [ "GPT": [
...@@ -158,7 +158,7 @@ ...@@ -158,7 +158,7 @@
"correctness" "correctness"
], ],
"Metrics": [ "Metrics": [
] ]
}, },
"STEM": { "STEM": {
"GPT": [ "GPT": [
...@@ -166,7 +166,7 @@ ...@@ -166,7 +166,7 @@
"correctness" "correctness"
], ],
"Metrics": [ "Metrics": [
] ]
}, },
"SocialScience": { "SocialScience": {
"GPT": [ "GPT": [
...@@ -174,7 +174,7 @@ ...@@ -174,7 +174,7 @@
"correctness" "correctness"
], ],
"Metrics": [ "Metrics": [
] ]
}, },
"Humanity": { "Humanity": {
"GPT": [ "GPT": [
...@@ -182,7 +182,7 @@ ...@@ -182,7 +182,7 @@
"correctness" "correctness"
], ],
"Metrics": [ "Metrics": [
] ]
}, },
"Other": { "Other": {
"GPT": [ "GPT": [
...@@ -190,7 +190,7 @@ ...@@ -190,7 +190,7 @@
"correctness" "correctness"
], ],
"Metrics": [ "Metrics": [
] ]
}, },
"ethics": { "ethics": {
"GPT": [ "GPT": [
...@@ -198,7 +198,7 @@ ...@@ -198,7 +198,7 @@
"correctness" "correctness"
], ],
"Metrics": [ "Metrics": [
] ]
} }
} }
} }
import argparse import argparse
import json
import os import os
import openai import openai
...@@ -9,7 +8,8 @@ from utils import jload ...@@ -9,7 +8,8 @@ from utils import jload
def main(args): def main(args):
assert len(args.answer_file_list) == len( assert len(args.answer_file_list) == len(
args.model_name_list), "The number of answer files and model names should be equal!" args.model_name_list
), "The number of answer files and model names should be equal!"
# load config # load config
config = jload(args.config_file) config = jload(args.config_file)
...@@ -36,7 +36,8 @@ def main(args): ...@@ -36,7 +36,8 @@ def main(args):
if len(args.model_name_list) == 1 and not gpt_evaluation_prompt: if len(args.model_name_list) == 1 and not gpt_evaluation_prompt:
raise Exception( raise Exception(
"No prompt file for gpt evaluation provided. Please specify the prompt file for gpt evaluation!") "No prompt file for gpt evaluation provided. Please specify the prompt file for gpt evaluation!"
)
if args.gpt_model == "text-davinci-003" and args.gpt_with_reference: if args.gpt_model == "text-davinci-003" and args.gpt_with_reference:
raise Exception( raise Exception(
...@@ -44,8 +45,15 @@ def main(args): ...@@ -44,8 +45,15 @@ def main(args):
) )
# initialize evaluator # initialize evaluator
evaluator = Evaluator(metrics_per_category, battle_prompt, gpt_evaluation_prompt, args.gpt_model, evaluator = Evaluator(
config["language"], config.get("path_for_UniEval", None), args.gpt_with_reference) metrics_per_category,
battle_prompt,
gpt_evaluation_prompt,
args.gpt_model,
config["language"],
config.get("path_for_UniEval", None),
args.gpt_with_reference,
)
if len(args.model_name_list) == 2: if len(args.model_name_list) == 2:
answers1 = jload(args.answer_file_list[0]) answers1 = jload(args.answer_file_list[0])
answers2 = jload(args.answer_file_list[1]) answers2 = jload(args.answer_file_list[1])
...@@ -68,41 +76,41 @@ def main(args): ...@@ -68,41 +76,41 @@ def main(args):
raise ValueError(f'Unsupported language {config["language"]}!') raise ValueError(f'Unsupported language {config["language"]}!')
if __name__ == '__main__': if __name__ == "__main__":
parser = argparse.ArgumentParser(description='ColossalAI LLM evaluation pipeline.') parser = argparse.ArgumentParser(description="ColossalAI LLM evaluation pipeline.")
parser.add_argument('--config_file', parser.add_argument(
type=str, "--config_file", type=str, default=None, required=True, help="path to the file of target results"
default=None, )
required=True, parser.add_argument("--battle_prompt_file", type=str, default=None, help="path to the prompt file for battle")
help='path to the file of target results') parser.add_argument(
parser.add_argument('--battle_prompt_file', type=str, default=None, help='path to the prompt file for battle') "--gpt_evaluation_prompt_file", type=str, default=None, help="path to the prompt file for gpt evaluation"
parser.add_argument('--gpt_evaluation_prompt_file', )
type=str, parser.add_argument("--target_file", type=str, default=None, help="path to the target answer (ground truth) file")
default=None, parser.add_argument(
help='path to the prompt file for gpt evaluation') "--answer_file_list",
parser.add_argument('--target_file', type=str, default=None, help='path to the target answer (ground truth) file') type=str,
parser.add_argument('--answer_file_list', nargs="+",
type=str, default=[],
nargs='+', required=True,
default=[], help="path to the answer files of at most 2 models",
required=True, )
help='path to the answer files of at most 2 models') parser.add_argument(
parser.add_argument('--model_name_list', "--model_name_list", type=str, nargs="+", default=[], required=True, help="the names of at most 2 models"
type=str, )
nargs='+', parser.add_argument(
default=[], "--gpt_model",
required=True, default="gpt-3.5-turbo",
help='the names of at most 2 models') choices=["text-davinci-003", "gpt-3.5-turbo", "gpt-4"],
parser.add_argument('--gpt_model', help="which GPT model to use for evaluation",
default="gpt-3.5-turbo", )
choices=["text-davinci-003", "gpt-3.5-turbo", "gpt-4"], parser.add_argument(
help='which GPT model to use for evaluation') "--gpt_with_reference",
parser.add_argument('--gpt_with_reference', default=False,
default=False, action="store_true",
action="store_true", help="whether to include reference answer in gpt evaluation",
help='whether to include reference answer in gpt evaluation') )
parser.add_argument('--save_path', type=str, default="results", help='path to save evaluation results') parser.add_argument("--save_path", type=str, default="results", help="path to save evaluation results")
parser.add_argument('--openai_key', type=str, default=None, required=True, help='Your openai key') parser.add_argument("--openai_key", type=str, default=None, required=True, help="Your openai key")
args = parser.parse_args() args = parser.parse_args()
if args.openai_key is not None: if args.openai_key is not None:
......
...@@ -3,20 +3,27 @@ from typing import Any, Dict, List ...@@ -3,20 +3,27 @@ from typing import Any, Dict, List
import gpt_evaluate import gpt_evaluate
import metrics import metrics
import pandas as pd
import unieval import unieval
from utils import analyze_automatic_results, get_data_per_category, save_automatic_results from utils import analyze_automatic_results, get_data_per_category, save_automatic_results
class Evaluator(object): class Evaluator(object):
""" """
A class named Evaluator includes GPT-3.5/GPT-4 evaluation A class named Evaluator includes GPT-3.5/GPT-4 evaluation
and automatic evaluation and automatic evaluation
""" """
def __init__(self, params: Dict[str, Any], battle_prompt: Dict[str, Any], gpt_evaluation_prompt: Dict[str, Any], def __init__(
gpt_model: str, language: str, path_for_UniEval: Dict[str, str], gpt_with_reference: bool) -> None: self,
params: Dict[str, Any],
battle_prompt: Dict[str, Any],
gpt_evaluation_prompt: Dict[str, Any],
gpt_model: str,
language: str,
path_for_UniEval: Dict[str, str],
gpt_with_reference: bool,
) -> None:
self.params = params self.params = params
self.battle_prompt = battle_prompt self.battle_prompt = battle_prompt
self.gpt_evaluation_prompt = gpt_evaluation_prompt self.gpt_evaluation_prompt = gpt_evaluation_prompt
...@@ -103,7 +110,8 @@ class Evaluator(object): ...@@ -103,7 +110,8 @@ class Evaluator(object):
if self.params[category]["UniEval"] and self.language == "cn": if self.params[category]["UniEval"] and self.language == "cn":
raise Exception( raise Exception(
"UniEval doesn't support Chinese! Please remove UniEval config in your Chinese config file.") "UniEval doesn't support Chinese! Please remove UniEval config in your Chinese config file."
)
category_metrics = self.params[category]["UniEval"] category_metrics = self.params[category]["UniEval"]
...@@ -134,10 +142,9 @@ class Evaluator(object): ...@@ -134,10 +142,9 @@ class Evaluator(object):
sources_list = [answer["instruction"] + answer["input"] for answer in answers_per_category[category]] sources_list = [answer["instruction"] + answer["input"] for answer in answers_per_category[category]]
data = unieval.convert_data_to_unieval_format(predicts_list, sources_list, targets_list) data = unieval.convert_data_to_unieval_format(predicts_list, sources_list, targets_list)
scores = uni_evaluator.evaluate(data, scores = uni_evaluator.evaluate(
category, data, category, dims=list(self.unieval_metric_stats[task][category].keys()), overall=False
dims=list(self.unieval_metric_stats[task][category].keys()), )
overall=False)
avg_scores = unieval.calculate_average_score(scores) avg_scores = unieval.calculate_average_score(scores)
self.unieval_metric_stats[task][category].update(avg_scores) self.unieval_metric_stats[task][category].update(avg_scores)
...@@ -165,7 +172,8 @@ class Evaluator(object): ...@@ -165,7 +172,8 @@ class Evaluator(object):
category, category,
self.gpt_model, self.gpt_model,
self.language, self.language,
references=targets_per_category[category] if self.gpt_with_reference else None) references=targets_per_category[category] if self.gpt_with_reference else None,
)
def save(self, path: str, model_name_list: List[str]) -> None: def save(self, path: str, model_name_list: List[str]) -> None:
""" """
...@@ -204,16 +212,18 @@ class Evaluator(object): ...@@ -204,16 +212,18 @@ class Evaluator(object):
gpt_base_save_path = os.path.join(path, "gpt_evaluate", "gpt_evaluate_results") gpt_base_save_path = os.path.join(path, "gpt_evaluate", "gpt_evaluate_results")
gpt_evaluation_results_save_path = os.path.join(gpt_base_save_path, "evaluation_results") gpt_evaluation_results_save_path = os.path.join(gpt_base_save_path, "evaluation_results")
all_evaluations = gpt_evaluate.save_gpt_evaluation_results(model_name_list[0], all_evaluations = gpt_evaluate.save_gpt_evaluation_results(
self.gpt_evaluation_results, model_name_list[0], self.gpt_evaluation_results, gpt_evaluation_results_save_path
gpt_evaluation_results_save_path) )
# Start to calculate scores and save statistics. # Start to calculate scores and save statistics.
gpt_evaluation_statistics_save_path = os.path.join(gpt_base_save_path, "evaluation_statistics") gpt_evaluation_statistics_save_path = os.path.join(gpt_base_save_path, "evaluation_statistics")
gpt_evaluate.save_gpt_evaluation_statistics(model_name_list[0], all_evaluations, gpt_evaluate.save_gpt_evaluation_statistics(
gpt_evaluation_statistics_save_path) model_name_list[0], all_evaluations, gpt_evaluation_statistics_save_path
)
# Save charts and csv. # Save charts and csv.
gpt_evaluation_analyses_save_path = os.path.join(gpt_base_save_path, "evaluation_analyses") gpt_evaluation_analyses_save_path = os.path.join(gpt_base_save_path, "evaluation_analyses")
gpt_evaluate.analyze_gpt_evaluation_statistics(gpt_evaluation_statistics_save_path, gpt_evaluate.analyze_gpt_evaluation_statistics(
gpt_evaluation_analyses_save_path) gpt_evaluation_statistics_save_path, gpt_evaluation_analyses_save_path
)
...@@ -14,20 +14,18 @@ import tqdm ...@@ -14,20 +14,18 @@ import tqdm
from utils import jdump, jload from utils import jdump, jload
ref_step_template = { ref_step_template = {
"en": "en": "Now please compare the answer with the {adjective} answer, determine whether the answer is able to achieve the same level of {metric}.\n\n",
"Now please compare the answer with the {adjective} answer, determine whether the answer is able to achieve the same level of {metric}.\n\n", "cn": "请比较答案与上面的{adjective}答案,确定答案是否可以达到与该{adjective}答案同样水平的{metric}。\n\n",
"cn":
"请比较答案与上面的{adjective}答案,确定答案是否可以达到与该{adjective}答案同样水平的{metric}。\n\n"
} }
ref_answer_template_general = { ref_answer_template_general = {
"en": "\nAn example answer with good quality is as follows:\n\n{answer}\n\n", "en": "\nAn example answer with good quality is as follows:\n\n{answer}\n\n",
"cn": "\n一个优质的示例答案如下:\n\n{answer}\n\n" "cn": "\n一个优质的示例答案如下:\n\n{answer}\n\n",
} }
ref_answer_template_correctness = { ref_answer_template_correctness = {
"en": "\nA correct answer is as follows:\n\n{answer}\n\n", "en": "\nA correct answer is as follows:\n\n{answer}\n\n",
"cn": "\n标准答案如下:\n\n{answer}\n\n" "cn": "\n标准答案如下:\n\n{answer}\n\n",
} }
...@@ -51,10 +49,7 @@ def get_battle_result(sys_prompt: str, user_prompt: str, id: int, max_tokens: in ...@@ -51,10 +49,7 @@ def get_battle_result(sys_prompt: str, user_prompt: str, id: int, max_tokens: in
response = openai.ChatCompletion.create( response = openai.ChatCompletion.create(
model="gpt-4", model="gpt-4",
messages=[ messages=[
{ {"role": "system", "content": sys_prompt},
"role": "system",
"content": sys_prompt
},
{ {
"role": "user", "role": "user",
"content": user_prompt, "content": user_prompt,
...@@ -106,7 +101,7 @@ def parse_battle_score(evaluation: str) -> List[float]: ...@@ -106,7 +101,7 @@ def parse_battle_score(evaluation: str) -> List[float]:
return [float(sp[0]), float(sp[1])] return [float(sp[0]), float(sp[1])]
else: else:
raise Exception(f"Invalid score pair. Got {evaluation}.") raise Exception(f"Invalid score pair. Got {evaluation}.")
except Exception as e: except Exception:
return [-1, -1] return [-1, -1]
...@@ -125,9 +120,6 @@ def battle(answer1: List[Dict], answer2: List[Dict], prompt_dict: Dict[str, Any] ...@@ -125,9 +120,6 @@ def battle(answer1: List[Dict], answer2: List[Dict], prompt_dict: Dict[str, Any]
assert len(answer1) == len(answer2) assert len(answer1) == len(answer2)
handles = []
evaluation_file = []
total_len = len(answer1) total_len = len(answer1)
question_idx_list = list(range(total_len)) question_idx_list = list(range(total_len))
...@@ -140,9 +132,12 @@ def battle(answer1: List[Dict], answer2: List[Dict], prompt_dict: Dict[str, Any] ...@@ -140,9 +132,12 @@ def battle(answer1: List[Dict], answer2: List[Dict], prompt_dict: Dict[str, Any]
assert answer1[i]["id"] == answer2[i]["id"] assert answer1[i]["id"] == answer2[i]["id"]
answer_id = answer1[i]["id"] answer_id = answer1[i]["id"]
ques = answer1[i]["instruction"] if answer1[i][ ques = (
"input"] == "" else answer1[i]["instruction"] + " " + answer1[i]["input"] answer1[i]["instruction"]
cat = answer1[i]["category"] if answer1[i]["input"] == ""
else answer1[i]["instruction"] + " " + answer1[i]["input"]
)
answer1[i]["category"]
ans1 = answer1[i]["output"] ans1 = answer1[i]["output"]
ans2 = answer2[i]["output"] ans2 = answer2[i]["output"]
...@@ -267,7 +262,11 @@ def reference_template(metric: str, language: str, reference: Dict[str, Any]) -> ...@@ -267,7 +262,11 @@ def reference_template(metric: str, language: str, reference: Dict[str, Any]) ->
step_to_add = ref_step_template[language] step_to_add = ref_step_template[language]
for_the_given_answer = "{metric} (1-5) (directly give the score for the given answer):" if language == "en" else "{metric} (1-5) (直接对给定答案打分)" for_the_given_answer = (
"{metric} (1-5) (directly give the score for the given answer):"
if language == "en"
else "{metric} (1-5) (直接对给定答案打分)"
)
# adjective is used to describe the word "answer" in the prompt. # adjective is used to describe the word "answer" in the prompt.
adjective = "example" if language == "en" else "示例" adjective = "example" if language == "en" else "示例"
...@@ -280,8 +279,9 @@ def reference_template(metric: str, language: str, reference: Dict[str, Any]) -> ...@@ -280,8 +279,9 @@ def reference_template(metric: str, language: str, reference: Dict[str, Any]) ->
answer_to_add = ref_answer_template_correctness[language] answer_to_add = ref_answer_template_correctness[language]
answer_to_add = answer_to_add.format(answer=reference["target"] if reference["target"] else reference["output"]) answer_to_add = answer_to_add.format(answer=reference["target"] if reference["target"] else reference["output"])
step_to_add = step_to_add.format(metric=metric.lower(), step_to_add = step_to_add.format(metric=metric.lower(), adjective=adjective) + for_the_given_answer.format(
adjective=adjective) + for_the_given_answer.format(metric=metric) metric=metric
)
return answer_to_add + step_to_add return answer_to_add + step_to_add
...@@ -329,7 +329,8 @@ def multiturn_chat_completion(user_messages: List[str], model: str, max_tokens: ...@@ -329,7 +329,8 @@ def multiturn_chat_completion(user_messages: List[str], model: str, max_tokens:
for j in range(i): for j in range(i):
messages_to_send.append(fill_in_message("user", user_messages[j])) messages_to_send.append(fill_in_message("user", user_messages[j]))
messages_to_send.append( messages_to_send.append(
fill_in_message("assistant", assistant_responses[j]["choices"][0]["message"]["content"])) fill_in_message("assistant", assistant_responses[j]["choices"][0]["message"]["content"])
)
# Length of user messages == Length of assistant messages + 1 # Length of user messages == Length of assistant messages + 1
# Because we always expect the api to response # Because we always expect the api to response
...@@ -351,13 +352,15 @@ def multiturn_chat_completion(user_messages: List[str], model: str, max_tokens: ...@@ -351,13 +352,15 @@ def multiturn_chat_completion(user_messages: List[str], model: str, max_tokens:
return assistant_responses[-1] return assistant_responses[-1]
def get_gpt_evaluation_without_logprobs(prompt: Dict[str, Any], def get_gpt_evaluation_without_logprobs(
inst: Dict[str, Any], prompt: Dict[str, Any],
metrics: List[str], inst: Dict[str, Any],
language: str, metrics: List[str],
reference: Dict[str, Any] = None, language: str,
model: str = "gpt-3.5-turbo", reference: Dict[str, Any] = None,
max_tokens: int = 2048) -> Dict[str, Any]: model: str = "gpt-3.5-turbo",
max_tokens: int = 2048,
) -> Dict[str, Any]:
""" """
Use chat models(gpt-3.5-turbo or gpt-4) to evaluate one model answer. Use chat models(gpt-3.5-turbo or gpt-4) to evaluate one model answer.
...@@ -378,7 +381,7 @@ def get_gpt_evaluation_without_logprobs(prompt: Dict[str, Any], ...@@ -378,7 +381,7 @@ def get_gpt_evaluation_without_logprobs(prompt: Dict[str, Any],
MAX_API_RETRY = 3 MAX_API_RETRY = 3
question = (inst["instruction"] if inst["input"] == "" else inst["instruction"] + "\n" + inst["input"]) question = inst["instruction"] if inst["input"] == "" else inst["instruction"] + "\n" + inst["input"]
answer = inst["output"] answer = inst["output"]
inst["evaluation"] = {} inst["evaluation"] = {}
...@@ -400,10 +403,9 @@ def get_gpt_evaluation_without_logprobs(prompt: Dict[str, Any], ...@@ -400,10 +403,9 @@ def get_gpt_evaluation_without_logprobs(prompt: Dict[str, Any],
if prompt_reference: if prompt_reference:
# Do a 2-round conversation # Do a 2-round conversation
response = multiturn_chat_completion([prompt_1st_round, prompt_reference], response = multiturn_chat_completion(
model, [prompt_1st_round, prompt_reference], model, max_tokens=max_tokens, turns=2
max_tokens=max_tokens, )
turns=2)
else: else:
response = multiturn_chat_completion([prompt_1st_round], model, max_tokens=max_tokens, turns=1) response = multiturn_chat_completion([prompt_1st_round], model, max_tokens=max_tokens, turns=1)
...@@ -427,10 +429,9 @@ def get_gpt_evaluation_without_logprobs(prompt: Dict[str, Any], ...@@ -427,10 +429,9 @@ def get_gpt_evaluation_without_logprobs(prompt: Dict[str, Any],
return inst return inst
def get_gpt_evaluation_with_logprobs(prompt: Dict[str, Any], def get_gpt_evaluation_with_logprobs(
inst: Dict[str, Any], prompt: Dict[str, Any], inst: Dict[str, Any], metrics: List[str], max_tokens: int = 2048
metrics: List[str], ) -> Dict[str, Any]:
max_tokens: int = 2048) -> Dict[str, Any]:
""" """
Use completion model(text-davinci-003) to evaluate one model answer. Use completion model(text-davinci-003) to evaluate one model answer.
Only completion models can return log probabilities. Only completion models can return log probabilities.
...@@ -449,7 +450,7 @@ def get_gpt_evaluation_with_logprobs(prompt: Dict[str, Any], ...@@ -449,7 +450,7 @@ def get_gpt_evaluation_with_logprobs(prompt: Dict[str, Any],
MAX_API_RETRY = 3 MAX_API_RETRY = 3
question = (inst["instruction"] if inst["input"] == "" else inst["instruction"] + "\n" + inst["input"]) question = inst["instruction"] if inst["input"] == "" else inst["instruction"] + "\n" + inst["input"]
answer = inst["output"] answer = inst["output"]
inst["evaluation"] = {} inst["evaluation"] = {}
...@@ -492,13 +493,15 @@ def get_gpt_evaluation_with_logprobs(prompt: Dict[str, Any], ...@@ -492,13 +493,15 @@ def get_gpt_evaluation_with_logprobs(prompt: Dict[str, Any],
return inst return inst
def evaluate(answers: List[Dict], def evaluate(
prompt: Dict[str, Any], answers: List[Dict],
metrics: List[str], prompt: Dict[str, Any],
category: str, metrics: List[str],
model: str, category: str,
language: str, model: str,
references: List[Dict] = None) -> List[Dict]: language: str,
references: List[Dict] = None,
) -> List[Dict]:
""" """
Use GPT models to evaluate model answers and save evaluation results. Use GPT models to evaluate model answers and save evaluation results.
...@@ -529,21 +532,23 @@ def evaluate(answers: List[Dict], ...@@ -529,21 +532,23 @@ def evaluate(answers: List[Dict],
if model == "text-davinci-003": if model == "text-davinci-003":
future = executor.submit(get_gpt_evaluation_with_logprobs, prompt, inst, metrics, 1) future = executor.submit(get_gpt_evaluation_with_logprobs, prompt, inst, metrics, 1)
else: else:
future = executor.submit(get_gpt_evaluation_without_logprobs, future = executor.submit(
prompt, get_gpt_evaluation_without_logprobs,
inst, prompt,
metrics, inst,
language, metrics,
reference=None if references is None else references[idx], language,
model=model, reference=None if references is None else references[idx],
max_tokens=1) model=model,
max_tokens=1,
)
futures.append(future) futures.append(future)
for future in tqdm.tqdm( for future in tqdm.tqdm(
concurrent.futures.as_completed(futures), concurrent.futures.as_completed(futures),
desc=f"{category}: ", desc=f"{category}: ",
total=len(futures), total=len(futures),
): ):
evaluations.append(future.result()) evaluations.append(future.result())
...@@ -610,12 +615,13 @@ def calculate_scores_form_response(response: str, evaluation: Dict[str, Any]) -> ...@@ -610,12 +615,13 @@ def calculate_scores_form_response(response: str, evaluation: Dict[str, Any]) ->
return int(results[0]) return int(results[0])
else: else:
raise Exception(f"Invalid score pair. Got {evaluation}.") raise Exception(f"Invalid score pair. Got {evaluation}.")
except Exception as e: except Exception:
return 0 return 0
def save_gpt_evaluation_results(model_name: str, gpt_evaluation_results: Dict[str, Any], def save_gpt_evaluation_results(
save_path: str) -> Dict[str, Any]: model_name: str, gpt_evaluation_results: Dict[str, Any], save_path: str
) -> Dict[str, Any]:
""" """
Save evaluation results for different categories for one model. Save evaluation results for different categories for one model.
...@@ -667,10 +673,12 @@ def save_gpt_evaluation_statistics(model_name: str, evaluations: List[Dict], sav ...@@ -667,10 +673,12 @@ def save_gpt_evaluation_statistics(model_name: str, evaluations: List[Dict], sav
scores[metric].append(0) scores[metric].append(0)
elif evaluation["evaluation"][metric]["logprobs"] is not None: elif evaluation["evaluation"][metric]["logprobs"] is not None:
scores[metric].append( scores[metric].append(
calculate_scores_form_logprobs(evaluation["evaluation"][metric]["logprobs"][0])) calculate_scores_form_logprobs(evaluation["evaluation"][metric]["logprobs"][0])
)
else: else:
scores[metric].append( scores[metric].append(
calculate_scores_form_response(evaluation["evaluation"][metric]["response"], evaluation)) calculate_scores_form_response(evaluation["evaluation"][metric]["response"], evaluation)
)
statistics = {} statistics = {}
for metric in metrics: for metric in metrics:
...@@ -751,9 +759,9 @@ def analyze_gpt_evaluation_statistics(statistics_path: str, save_path: str) -> N ...@@ -751,9 +759,9 @@ def analyze_gpt_evaluation_statistics(statistics_path: str, save_path: str) -> N
frame_all.to_csv(os.path.join(save_path, "gpt_evaluation_statistics.csv")) frame_all.to_csv(os.path.join(save_path, "gpt_evaluation_statistics.csv"))
for category in tqdm.tqdm( for category in tqdm.tqdm(
frame_per_category.keys(), frame_per_category.keys(),
desc=f"GPT evaluation: ", desc=f"GPT evaluation: ",
total=len(frame_per_category.keys()), total=len(frame_per_category.keys()),
): ):
data = pd.DataFrame(frame_per_category[category]) data = pd.DataFrame(frame_per_category[category])
......
...@@ -21,13 +21,17 @@ def bleu_score(preds: List[str], targets: List[str], language: str) -> Dict[str, ...@@ -21,13 +21,17 @@ def bleu_score(preds: List[str], targets: List[str], language: str) -> Dict[str,
""" """
bleu_scores = {"bleu1": 0, "bleu2": 0, "bleu3": 0, "bleu4": 0} bleu_scores = {"bleu1": 0, "bleu2": 0, "bleu3": 0, "bleu4": 0}
cumulative_bleu = [0] * 4 cumulative_bleu = [0] * 4
weights = [(1. / 1., 0., 0., 0.), (1. / 2., 1. / 2., 0., 0.), (1. / 3., 1. / 3., 1. / 3., 0.), weights = [
(1. / 4., 1. / 4., 1. / 4., 1. / 4.)] (1.0 / 1.0, 0.0, 0.0, 0.0),
(1.0 / 2.0, 1.0 / 2.0, 0.0, 0.0),
(1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0, 0.0),
(1.0 / 4.0, 1.0 / 4.0, 1.0 / 4.0, 1.0 / 4.0),
]
for pred, target in zip(preds, targets): for pred, target in zip(preds, targets):
if language == "cn": if language == "cn":
pred_list = ' '.join(jieba.cut(preprocessing_text(pred))).split() pred_list = " ".join(jieba.cut(preprocessing_text(pred))).split()
target_list = [(' '.join(jieba.cut(preprocessing_text(target)))).split()] target_list = [(" ".join(jieba.cut(preprocessing_text(target)))).split()]
elif language == "en": elif language == "en":
pred_list = preprocessing_text(pred).split() pred_list = preprocessing_text(pred).split()
target_list = [preprocessing_text(target).split()] target_list = [preprocessing_text(target).split()]
...@@ -42,15 +46,14 @@ def bleu_score(preds: List[str], targets: List[str], language: str) -> Dict[str, ...@@ -42,15 +46,14 @@ def bleu_score(preds: List[str], targets: List[str], language: str) -> Dict[str,
def chrf_score(preds: List[str], targets: List[str], language: str) -> Dict[str, float]: def chrf_score(preds: List[str], targets: List[str], language: str) -> Dict[str, float]:
"""Calculate CHRF Score Metric in sentence level. """Calculate CHRF Score Metric in sentence level."""
"""
chrf_score = {"chrf": 0} chrf_score = {"chrf": 0}
cumulative_chrf = [] cumulative_chrf = []
for pred, target in zip(preds, targets): for pred, target in zip(preds, targets):
if language == "cn": if language == "cn":
pred_list = ' '.join(jieba.cut(preprocessing_text(pred))).split() pred_list = " ".join(jieba.cut(preprocessing_text(pred))).split()
target_list = ' '.join(jieba.cut(preprocessing_text(target))).split() target_list = " ".join(jieba.cut(preprocessing_text(target))).split()
elif language == "en": elif language == "en":
pred_list = preprocessing_text(pred).split() pred_list = preprocessing_text(pred).split()
target_list = preprocessing_text(target).split() target_list = preprocessing_text(target).split()
...@@ -75,8 +78,8 @@ def rouge_cn_score(preds: List[str], targets: List[str]) -> Dict[str, float]: ...@@ -75,8 +78,8 @@ def rouge_cn_score(preds: List[str], targets: List[str]) -> Dict[str, float]:
all_targets = [] all_targets = []
for pred, target in zip(preds, targets): for pred, target in zip(preds, targets):
pred_list = remove_redundant_space(' '.join(jieba.cut(preprocessing_text(pred)))) pred_list = remove_redundant_space(" ".join(jieba.cut(preprocessing_text(pred))))
target_list = remove_redundant_space(' '.join(jieba.cut(preprocessing_text(target)))) target_list = remove_redundant_space(" ".join(jieba.cut(preprocessing_text(target))))
all_preds.append(pred_list) all_preds.append(pred_list)
all_targets.append(target_list) all_targets.append(target_list)
...@@ -99,16 +102,14 @@ def rouge_en_score(preds: List[str], targets: List[str]) -> Dict[str, float]: ...@@ -99,16 +102,14 @@ def rouge_en_score(preds: List[str], targets: List[str]) -> Dict[str, float]:
longest common subsequence (LCS) between preds and targets. longest common subsequence (LCS) between preds and targets.
""" """
rouge_scores = {"rouge1": 0, "rouge2": 0, "rougeL": 0} rouge_scores = {"rouge1": 0, "rouge2": 0, "rougeL": 0}
all_preds = []
all_targets = []
rouge_en = Rouge_en.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=False) rouge_en = Rouge_en.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=False)
for pred, target in zip(preds, targets): for pred, target in zip(preds, targets):
score = rouge_en.score(preprocessing_text(pred), preprocessing_text(target)) score = rouge_en.score(preprocessing_text(pred), preprocessing_text(target))
rouge_scores["rouge1"] += score['rouge1'].fmeasure rouge_scores["rouge1"] += score["rouge1"].fmeasure
rouge_scores["rouge2"] += score['rouge2'].fmeasure rouge_scores["rouge2"] += score["rouge2"].fmeasure
rouge_scores["rougeL"] += score['rougeL'].fmeasure rouge_scores["rougeL"] += score["rougeL"].fmeasure
rouge_scores["rouge1"] = rouge_scores["rouge1"] / len(preds) rouge_scores["rouge1"] = rouge_scores["rouge1"] / len(preds)
rouge_scores["rouge2"] = rouge_scores["rouge2"] / len(preds) rouge_scores["rouge2"] = rouge_scores["rouge2"] / len(preds)
...@@ -137,7 +138,7 @@ def distinct_score(preds: List[str], language: str) -> Dict[str, float]: ...@@ -137,7 +138,7 @@ def distinct_score(preds: List[str], language: str) -> Dict[str, float]:
for pred in preds: for pred in preds:
if language == "cn": if language == "cn":
pred_seg_list = ' '.join(jieba.cut(pred)).split() pred_seg_list = " ".join(jieba.cut(pred)).split()
count_segs = len(pred_seg_list) count_segs = len(pred_seg_list)
unique_segs = set(pred_seg_list) unique_segs = set(pred_seg_list)
count_unique_chars = len(unique_segs) count_unique_chars = len(unique_segs)
...@@ -151,7 +152,7 @@ def distinct_score(preds: List[str], language: str) -> Dict[str, float]: ...@@ -151,7 +152,7 @@ def distinct_score(preds: List[str], language: str) -> Dict[str, float]:
split_pred = preprocessing_text(pred).split() split_pred = preprocessing_text(pred).split()
for n in range(0, 3): for n in range(0, 3):
for i in range(0, len(split_pred) - n): for i in range(0, len(split_pred) - n):
ngram = ' '.join(split_pred[i:i + n + 1]) ngram = " ".join(split_pred[i : i + n + 1])
unique_ngram[n].add(ngram) unique_ngram[n].add(ngram)
all_ngram_count[n] += 1 all_ngram_count[n] += 1
...@@ -203,8 +204,8 @@ def calculate_precision_recall_f1(preds: List[str], targets: List[str], language ...@@ -203,8 +204,8 @@ def calculate_precision_recall_f1(preds: List[str], targets: List[str], language
for pred, target in zip(preds, targets): for pred, target in zip(preds, targets):
if language == "cn": if language == "cn":
pred_list = [char for char in ' '.join(jieba.cut(preprocessing_text(pred))).split()] pred_list = [char for char in " ".join(jieba.cut(preprocessing_text(pred))).split()]
target_list = [char for char in ' '.join(jieba.cut(preprocessing_text(target))).split()] target_list = [char for char in " ".join(jieba.cut(preprocessing_text(target))).split()]
elif language == "en": elif language == "en":
pred_list = [char for char in preprocessing_text(pred).split()] pred_list = [char for char in preprocessing_text(pred).split()]
target_list = [char for char in preprocessing_text(target).split()] target_list = [char for char in preprocessing_text(target).split()]
......
...@@ -7,6 +7,9 @@ from .utils import ( ...@@ -7,6 +7,9 @@ from .utils import (
) )
__all__ = [ __all__ = [
'get_evaluator', 'convert_data_to_unieval_format', 'calculate_average_score', 'save_unieval_results', "get_evaluator",
'analyze_unieval_results' "convert_data_to_unieval_format",
"calculate_average_score",
"save_unieval_results",
"analyze_unieval_results",
] ]
...@@ -28,29 +28,29 @@ from .utils import add_question ...@@ -28,29 +28,29 @@ from .utils import add_question
class SumEvaluator: class SumEvaluator:
def __init__(self, model_name_or_path, max_length=1024, device="cuda:0", cache_dir=None):
def __init__(self, model_name_or_path, max_length=1024, device='cuda:0', cache_dir=None): """Set up evaluator for text summarization"""
""" Set up evaluator for text summarization """
self.scorer = UniEvaluator( self.scorer = UniEvaluator(
model_name_or_path='MingZhong/unieval-sum' if model_name_or_path == "" else model_name_or_path, model_name_or_path="MingZhong/unieval-sum" if model_name_or_path == "" else model_name_or_path,
max_length=max_length, max_length=max_length,
device=device, device=device,
cache_dir=cache_dir) cache_dir=cache_dir,
self.task = 'summarization' )
self.dimensions = ['coherence', 'consistency', 'fluency', 'relevance'] self.task = "summarization"
self.dimensions = ["coherence", "consistency", "fluency", "relevance"]
def evaluate(self, data, category, dims=None, overall=True): def evaluate(self, data, category, dims=None, overall=True):
""" """
Get the scores of all the given dimensions Get the scores of all the given dimensions
category: The category to be evaluated. category: The category to be evaluated.
dims: A list of dimensions to be evaluated. If dims is None, SumEvaluator will evaluate dims: A list of dimensions to be evaluated. If dims is None, SumEvaluator will evaluate
four dimensions: coherence, consistency, fluency, relevance. four dimensions: coherence, consistency, fluency, relevance.
overall: indicates whether the overall score is to be calculated. overall: indicates whether the overall score is to be calculated.
Overall score can be customized to a combination of scores based on different Overall score can be customized to a combination of scores based on different
dimensions. The default here is the average score of all the given dimensions. dimensions. The default here is the average score of all the given dimensions.
""" """
n_data = len(data) n_data = len(data)
eval_scores = [{} for _ in range(n_data)] eval_scores = [{} for _ in range(n_data)]
...@@ -63,12 +63,12 @@ class SumEvaluator: ...@@ -63,12 +63,12 @@ class SumEvaluator:
for dim in eval_dims: for dim in eval_dims:
# Calculate average sentence-level scores for 'consistency' and 'fluency' # Calculate average sentence-level scores for 'consistency' and 'fluency'
if dim == 'consistency' or dim == 'fluency': if dim == "consistency" or dim == "fluency":
src_list, output_list = [], [] src_list, output_list = [], []
n_sents = [] # the number of sentences in each generated summary n_sents = [] # the number of sentences in each generated summary
for i in range(n_data): for i in range(n_data):
source = data[i]['source'] source = data[i]["source"]
system_outputs = sent_tokenize(data[i]['system_output']) system_outputs = sent_tokenize(data[i]["system_output"])
n_sents.append(len(system_outputs)) n_sents.append(len(system_outputs))
for j in range(len(system_outputs)): for j in range(len(system_outputs)):
src_list.append(source) src_list.append(source)
...@@ -81,24 +81,26 @@ class SumEvaluator: ...@@ -81,24 +81,26 @@ class SumEvaluator:
score = [] score = []
for cur_n_sent in n_sents: for cur_n_sent in n_sents:
# prevent denominator from being 0 # prevent denominator from being 0
score.append(sum(sent_score[start_idx:start_idx + cur_n_sent]) / (cur_n_sent + 1e-6)) score.append(sum(sent_score[start_idx : start_idx + cur_n_sent]) / (cur_n_sent + 1e-6))
start_idx += cur_n_sent start_idx += cur_n_sent
# Calculate summary-level score for 'coherence' and 'relevance' # Calculate summary-level score for 'coherence' and 'relevance'
elif dim == 'coherence' or dim == 'relevance': elif dim == "coherence" or dim == "relevance":
src_list, output_list, ref_list = [], [], [] src_list, output_list, ref_list = [], [], []
for i in range(n_data): for i in range(n_data):
src_list.append(data[i]['source']) src_list.append(data[i]["source"])
output_list.append(data[i]['system_output']) output_list.append(data[i]["system_output"])
if dim == 'relevance': if dim == "relevance":
ref_list.append(data[i]['reference']) ref_list.append(data[i]["reference"])
input_list = add_question(dimension=dim, output=output_list, src=src_list, ref=ref_list, task=self.task) input_list = add_question(dimension=dim, output=output_list, src=src_list, ref=ref_list, task=self.task)
score = self.scorer.score(input_list, self.task, category, dim) score = self.scorer.score(input_list, self.task, category, dim)
# Please customize other dimensions here for summarization # Please customize other dimensions here for summarization
else: else:
raise NotImplementedError('The input format for this dimension is still undefined. \ raise NotImplementedError(
Please customize it first.') "The input format for this dimension is still undefined. \
Please customize it first."
)
for i in range(n_data): for i in range(n_data):
eval_scores[i][dim] = score[i] eval_scores[i][dim] = score[i]
...@@ -106,35 +108,35 @@ class SumEvaluator: ...@@ -106,35 +108,35 @@ class SumEvaluator:
# Customize your overall score here. # Customize your overall score here.
if overall == True: if overall == True:
for i in range(n_data): for i in range(n_data):
eval_scores[i]['overall'] = np.mean(list(eval_scores[i].values())) eval_scores[i]["overall"] = np.mean(list(eval_scores[i].values()))
return eval_scores return eval_scores
class DialogEvaluator: class DialogEvaluator:
def __init__(self, model_name_or_path, max_length=1024, device="cuda:0", cache_dir=None):
def __init__(self, model_name_or_path, max_length=1024, device='cuda:0', cache_dir=None): """Set up evaluator for dialogues"""
""" Set up evaluator for dialogues """
self.scorer = UniEvaluator( self.scorer = UniEvaluator(
model_name_or_path='MingZhong/unieval-dialog' if model_name_or_path == "" else model_name_or_path, model_name_or_path="MingZhong/unieval-dialog" if model_name_or_path == "" else model_name_or_path,
max_length=max_length, max_length=max_length,
device=device, device=device,
cache_dir=cache_dir) cache_dir=cache_dir,
self.task = 'dialogue' )
self.dimensions = ['naturalness', 'coherence', 'engagingness', 'groundedness', 'understandability'] self.task = "dialogue"
self.dimensions = ["naturalness", "coherence", "engagingness", "groundedness", "understandability"]
def evaluate(self, data, category, dims=None, overall=True): def evaluate(self, data, category, dims=None, overall=True):
""" """
Get the scores of all the given dimensions Get the scores of all the given dimensions
category: The category to be evaluated. category: The category to be evaluated.
dims: A list of dimensions to be evaluated. If dims is None, DialogEvaluator will evaluate dims: A list of dimensions to be evaluated. If dims is None, DialogEvaluator will evaluate
five dimensions: naturalness, coherence, engagingness, groundedness and understandability. five dimensions: naturalness, coherence, engagingness, groundedness and understandability.
overall: indicates whether the overall score is to be calculated. overall: indicates whether the overall score is to be calculated.
Overall score can be customized to a combination of scores based on different Overall score can be customized to a combination of scores based on different
dimensions. The default here is the average score of all the given dimensions. dimensions. The default here is the average score of all the given dimensions.
""" """
n_data = len(data) n_data = len(data)
eval_scores = [{} for _ in range(n_data)] eval_scores = [{} for _ in range(n_data)]
...@@ -147,50 +149,48 @@ class DialogEvaluator: ...@@ -147,50 +149,48 @@ class DialogEvaluator:
for dim in eval_dims: for dim in eval_dims:
# Calculate summation score for 'engagingness' # Calculate summation score for 'engagingness'
if dim == 'engagingness': if dim == "engagingness":
src_list, output_list, context_list = [], [], [] src_list, output_list, context_list = [], [], []
n_sents = [] # the number of sentences in each generated response n_sents = [] # the number of sentences in each generated response
for i in range(n_data): for i in range(n_data):
source = data[i]['source'] source = data[i]["source"]
context = data[i]['context'] context = data[i]["context"]
system_outputs = sent_tokenize(data[i]['system_output']) system_outputs = sent_tokenize(data[i]["system_output"])
n_sents.append(len(system_outputs)) n_sents.append(len(system_outputs))
for j in range(len(system_outputs)): for j in range(len(system_outputs)):
src_list.append(source) src_list.append(source)
context_list.append(context) context_list.append(context)
output_list.append(system_outputs[j]) output_list.append(system_outputs[j])
input_list = add_question(dimension=dim, input_list = add_question(
output=output_list, dimension=dim, output=output_list, src=src_list, context=context_list, task=self.task
src=src_list, )
context=context_list,
task=self.task)
sent_score = self.scorer.score(input_list, self.task, category, dim) sent_score = self.scorer.score(input_list, self.task, category, dim)
# Get the summation score for each sample # Get the summation score for each sample
start_idx = 0 start_idx = 0
score = [] score = []
for cur_n_sent in n_sents: for cur_n_sent in n_sents:
score.append(sum(sent_score[start_idx:start_idx + cur_n_sent])) score.append(sum(sent_score[start_idx : start_idx + cur_n_sent]))
start_idx += cur_n_sent start_idx += cur_n_sent
# Calculate turn-level score for other dimensions # Calculate turn-level score for other dimensions
elif dim in ['naturalness', 'coherence', 'groundedness', 'understandability']: elif dim in ["naturalness", "coherence", "groundedness", "understandability"]:
src_list, output_list, context_list = [], [], [] src_list, output_list, context_list = [], [], []
for i in range(n_data): for i in range(n_data):
src_list.append(data[i]['source']) src_list.append(data[i]["source"])
output_list.append(data[i]['system_output']) output_list.append(data[i]["system_output"])
context_list.append(data[i]['context']) context_list.append(data[i]["context"])
input_list = add_question(dimension=dim, input_list = add_question(
output=output_list, dimension=dim, output=output_list, src=src_list, context=context_list, task=self.task
src=src_list, )
context=context_list,
task=self.task)
score = self.scorer.score(input_list, self.task, category, dim) score = self.scorer.score(input_list, self.task, category, dim)
# Please customize other dimensions here for summarization # Please customize other dimensions here for summarization
else: else:
raise NotImplementedError('The input format for this dimension is still undefined. \ raise NotImplementedError(
Please customize it first.') "The input format for this dimension is still undefined. \
Please customize it first."
)
for i in range(n_data): for i in range(n_data):
eval_scores[i][dim] = score[i] eval_scores[i][dim] = score[i]
...@@ -198,35 +198,35 @@ class DialogEvaluator: ...@@ -198,35 +198,35 @@ class DialogEvaluator:
# Customize your overall score here. # Customize your overall score here.
if overall == True: if overall == True:
for i in range(n_data): for i in range(n_data):
eval_scores[i]['overall'] = np.mean(list(eval_scores[i].values())) eval_scores[i]["overall"] = np.mean(list(eval_scores[i].values()))
return eval_scores return eval_scores
class D2tEvaluator: class D2tEvaluator:
def __init__(self, model_name_or_path, max_length=1024, device="cuda:0", cache_dir=None):
    """
    Create the UniEval scorer used for data-to-text evaluation.

    model_name_or_path: checkpoint handed to UniEvaluator; an empty string
        selects "MingZhong/unieval-sum"
    max_length, device, cache_dir: forwarded unchanged to UniEvaluator
    """
    # Only the empty string triggers the fallback checkpoint.
    if model_name_or_path == "":
        checkpoint = "MingZhong/unieval-sum"
    else:
        checkpoint = model_name_or_path

    self.scorer = UniEvaluator(
        model_name_or_path=checkpoint, max_length=max_length, device=device, cache_dir=cache_dir
    )
    self.task = "data2text"
    self.dimensions = ["naturalness", "informativeness"]
def evaluate(self, data, category, dims=None, overall=True): def evaluate(self, data, category, dims=None, overall=True):
""" """
Get the scores of all the given dimensions Get the scores of all the given dimensions
category: The category to be evaluated. category: The category to be evaluated.
dims: A list of dimensions to be evaluated. If dims is None, D2tEvaluator will evaluate dims: A list of dimensions to be evaluated. If dims is None, D2tEvaluator will evaluate
two dimensions: naturalness and informativeness. two dimensions: naturalness and informativeness.
overall: indicates whether the overall score is to be calculated. overall: indicates whether the overall score is to be calculated.
Overall score can be customized to a combination of scores based on different Overall score can be customized to a combination of scores based on different
dimensions. The default here is the average score of all the given dimensions. dimensions. The default here is the average score of all the given dimensions.
""" """
n_data = len(data) n_data = len(data)
eval_scores = [{} for _ in range(n_data)] eval_scores = [{} for _ in range(n_data)]
...@@ -240,8 +240,8 @@ class D2tEvaluator: ...@@ -240,8 +240,8 @@ class D2tEvaluator:
for dim in eval_dims: for dim in eval_dims:
output_list, ref_list = [], [] output_list, ref_list = [], []
for i in range(n_data): for i in range(n_data):
output_list.append(data[i]['system_output']) output_list.append(data[i]["system_output"])
ref_list.append(data[i]['reference']) ref_list.append(data[i]["reference"])
input_list = add_question(dimension=dim, output=output_list, ref=ref_list, task=self.task) input_list = add_question(dimension=dim, output=output_list, ref=ref_list, task=self.task)
score = self.scorer.score(input_list, self.task, category, dim) score = self.scorer.score(input_list, self.task, category, dim)
...@@ -252,38 +252,38 @@ class D2tEvaluator: ...@@ -252,38 +252,38 @@ class D2tEvaluator:
# Customize your overall score here. # Customize your overall score here.
if overall == True: if overall == True:
for i in range(n_data): for i in range(n_data):
eval_scores[i]['overall'] = np.mean(list(eval_scores[i].values())) eval_scores[i]["overall"] = np.mean(list(eval_scores[i].values()))
return eval_scores return eval_scores
class FactEvaluator: class FactEvaluator:
def __init__(self, model_name_or_path, max_length=1024, device="cuda:0", cache_dir=None):
    """
    Create the UniEval scorer used for factual consistency detection.

    model_name_or_path: checkpoint handed to UniEvaluator; an empty string
        selects "MingZhong/unieval-fact"
    max_length, device, cache_dir: forwarded unchanged to UniEvaluator
    """
    # Only the empty string triggers the fallback checkpoint.
    if model_name_or_path == "":
        checkpoint = "MingZhong/unieval-fact"
    else:
        checkpoint = model_name_or_path

    self.scorer = UniEvaluator(
        model_name_or_path=checkpoint, max_length=max_length, device=device, cache_dir=cache_dir
    )
    self.task = "fact"
    self.dim = "consistency"
def evaluate(self, data, category): def evaluate(self, data, category):
""" """
Get the factual consistency score (only 1 dimension for this task) Get the factual consistency score (only 1 dimension for this task)
category: The category to be evaluated. category: The category to be evaluated.
""" """
n_data = len(data) n_data = len(data)
eval_scores = [{} for _ in range(n_data)] eval_scores = [{} for _ in range(n_data)]
# Calculate average sentence-level scores for factual consistency # Calculate average sentence-level scores for factual consistency
src_list, output_list = [], [] src_list, output_list = [], []
n_sents = [] # the number of sentences in the claim n_sents = [] # the number of sentences in the claim
for i in range(n_data): for i in range(n_data):
source = data[i]['source'] source = data[i]["source"]
system_outputs = sent_tokenize(data[i]['system_output']) system_outputs = sent_tokenize(data[i]["system_output"])
n_sents.append(len(system_outputs)) n_sents.append(len(system_outputs))
for j in range(len(system_outputs)): for j in range(len(system_outputs)):
src_list.append(source) src_list.append(source)
...@@ -295,7 +295,7 @@ class FactEvaluator: ...@@ -295,7 +295,7 @@ class FactEvaluator:
start_idx = 0 start_idx = 0
score = [] score = []
for cur_n_sent in n_sents: for cur_n_sent in n_sents:
score.append(sum(sent_score[start_idx:start_idx + cur_n_sent]) / cur_n_sent) score.append(sum(sent_score[start_idx : start_idx + cur_n_sent]) / cur_n_sent)
start_idx += cur_n_sent start_idx += cur_n_sent
for i in range(n_data): for i in range(n_data):
...@@ -304,28 +304,26 @@ class FactEvaluator: ...@@ -304,28 +304,26 @@ class FactEvaluator:
return eval_scores return eval_scores
def get_evaluator(task, model_name_or_path="", max_length=1024, device="cuda:0", cache_dir=None):
    """
    Build the evaluator that matches the requested UniEval task.

    task: one of "summarization", "dialogue", "data2text" or "fact"
    model_name_or_path, max_length, device, cache_dir: forwarded unchanged to
        the selected evaluator's constructor

    Returns a SumEvaluator, DialogEvaluator, D2tEvaluator or FactEvaluator.

    Raises AssertionError for any other task. When assertions are disabled
    (``python -O``) the fallback branch raises NotImplementedError instead.
    """
    assert task in ["summarization", "dialogue", "data2text", "fact"]

    if task == "summarization":
        evaluator_cls = SumEvaluator
    elif task == "dialogue":
        evaluator_cls = DialogEvaluator
    elif task == "data2text":
        evaluator_cls = D2tEvaluator
    elif task == "fact":
        evaluator_cls = FactEvaluator
    else:
        # Keep the message on one line: a backslash continuation inside the
        # literal would embed the next line's indentation in the text.
        raise NotImplementedError("Other tasks are not implemented, please customize specific tasks here.")

    return evaluator_cls(
        model_name_or_path=model_name_or_path, max_length=max_length, device=device, cache_dir=cache_dir
    )
...@@ -27,9 +27,8 @@ from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer ...@@ -27,9 +27,8 @@ from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer
class UniEvaluator: class UniEvaluator:
def __init__(self, model_name_or_path, max_length=1024, device="cuda:0", cache_dir=None):
def __init__(self, model_name_or_path, max_length=1024, device='cuda:0', cache_dir=None): """Set up model"""
""" Set up model """
self.device = device self.device = device
self.max_length = max_length self.max_length = max_length
...@@ -47,8 +46,8 @@ class UniEvaluator: ...@@ -47,8 +46,8 @@ class UniEvaluator:
def score(self, inputs, task, category, dim, batch_size=8): def score(self, inputs, task, category, dim, batch_size=8):
""" """
Get scores for the given samples. Get scores for the given samples.
final_score = postive_score / (postive_score + negative_score) final_score = postive_score / (postive_score + negative_score)
""" """
# The implementation of "forward" in T5 still requires decoder_input_ids. # The implementation of "forward" in T5 still requires decoder_input_ids.
...@@ -58,31 +57,27 @@ class UniEvaluator: ...@@ -58,31 +57,27 @@ class UniEvaluator:
pos_score_list, neg_score_list = [], [] pos_score_list, neg_score_list = [], []
for i in tqdm(range(0, len(inputs), batch_size), desc=f"{category}-({dim}-{task}): "): for i in tqdm(range(0, len(inputs), batch_size), desc=f"{category}-({dim}-{task}): "):
src_list = inputs[i:i + batch_size] src_list = inputs[i : i + batch_size]
tgt_list = tgts[i:i + batch_size] tgt_list = tgts[i : i + batch_size]
try: try:
with torch.no_grad(): with torch.no_grad():
encoded_src = self.tokenizer(src_list, encoded_src = self.tokenizer(
max_length=self.max_length, src_list, max_length=self.max_length, truncation=True, padding=True, return_tensors="pt"
truncation=True, )
padding=True, encoded_tgt = self.tokenizer(
return_tensors='pt') tgt_list, max_length=self.max_length, truncation=True, padding=True, return_tensors="pt"
encoded_tgt = self.tokenizer(tgt_list, )
max_length=self.max_length,
truncation=True, src_tokens = encoded_src["input_ids"].to(self.device)
padding=True, src_mask = encoded_src["attention_mask"].to(self.device)
return_tensors='pt')
tgt_tokens = encoded_tgt["input_ids"].to(self.device)[:, 0].unsqueeze(-1)
src_tokens = encoded_src['input_ids'].to(self.device)
src_mask = encoded_src['attention_mask'].to(self.device)
tgt_tokens = encoded_tgt['input_ids'].to(self.device)[:, 0].unsqueeze(-1)
output = self.model(input_ids=src_tokens, attention_mask=src_mask, labels=tgt_tokens) output = self.model(input_ids=src_tokens, attention_mask=src_mask, labels=tgt_tokens)
logits = output.logits.view(-1, self.model.config.vocab_size) logits = output.logits.view(-1, self.model.config.vocab_size)
pos_score = self.softmax(logits)[:, self.pos_id] # Yes pos_score = self.softmax(logits)[:, self.pos_id] # Yes
neg_score = self.softmax(logits)[:, self.neg_id] # No neg_score = self.softmax(logits)[:, self.neg_id] # No
cur_pos_score = [x.item() for x in pos_score] cur_pos_score = [x.item() for x in pos_score]
cur_neg_score = [x.item() for x in neg_score] cur_neg_score = [x.item() for x in neg_score]
...@@ -90,8 +85,8 @@ class UniEvaluator: ...@@ -90,8 +85,8 @@ class UniEvaluator:
neg_score_list += cur_neg_score neg_score_list += cur_neg_score
except RuntimeError: except RuntimeError:
print(f'source: {src_list}') print(f"source: {src_list}")
print(f'target: {tgt_list}') print(f"target: {tgt_list}")
exit(0) exit(0)
score_list = [] score_list = []
......
...@@ -31,105 +31,142 @@ import tqdm ...@@ -31,105 +31,142 @@ import tqdm
def add_question(dimension, output, src=None, ref=None, context=None, task=None):
    """
    Wrap every model output in the Bool-QA style prompt that UniEval expects.

    dimension: the evaluation dimension the question asks about
    output: list of texts generated by the model; one prompt is built per entry
    src: source inputs (e.g. source documents for summarization, dialogue
        histories for dialogue), indexed in parallel with output
    ref: human-annotated references, indexed in parallel with output
    context: additional facts used by the engagingness and groundedness
        dimensions of the dialogue task, indexed in parallel with output
    task: "summarization", "dialogue", "data2text" or "fact"

    Returns the prompts in the same order as output. Task and dimension are
    only checked while a prompt is being built, so an empty output yields an
    empty list; otherwise an unsupported task or dimension raises
    NotImplementedError.
    """
    undefined_dimension = "The input format for this dimension is still undefined. Please customize it first."

    def summarization_prompt(idx):
        # Prompts for the summarization task.
        if dimension == "fluency":
            return "question: Is this a fluent paragraph? </s> paragraph: " + output[idx]
        if dimension == "coherence":
            return (
                "question: Is this a coherent summary to the document? </s> summary: "
                + output[idx]
                + " </s> document: "
                + src[idx]
            )
        if dimension == "consistency":
            return (
                "question: Is this claim consistent with the document? </s> claim: "
                + output[idx]
                + " </s> document: "
                + src[idx]
            )
        if dimension == "relevance":
            return (
                "question: Is this summary relevant to the reference? </s> summary: "
                + output[idx]
                + " </s> reference: "
                + ref[idx]
            )
        raise NotImplementedError(undefined_dimension)

    def dialogue_prompt(idx):
        # Prompts for dialogue response generation.
        if dimension == "naturalness":
            return "question: Is this a natural response in the dialogue? </s> response: " + output[idx]
        if dimension == "coherence":
            return (
                "question: Is this a coherent response given the dialogue history? </s> response: "
                + output[idx]
                + " </s> dialogue history: "
                + src[idx]
            )
        if dimension == "engagingness":
            return (
                "question: Is this an engaging and informative response according to the dialogue history and fact? </s> response: "
                + output[idx]
                + " </s> dialogue history: "
                + src[idx]
                + " </s> fact: "
                + context[idx]
            )
        if dimension == "groundedness":
            return (
                "question: Is this response consistent with knowledge in the fact? </s> response: "
                + output[idx]
                + " </s> fact: "
                + context[idx]
            )
        if dimension == "understandability":
            return "question: Is this an understandable response in the dialogue? </s> response: " + output[idx]
        raise NotImplementedError(undefined_dimension)

    def data2text_prompt(idx):
        # Prompts for data-to-text generation.
        if dimension == "naturalness":
            return "question: Is this a fluent utterance? </s> utterance: " + output[idx]
        if dimension == "informativeness":
            return (
                "question: Is this sentence informative according to the reference? </s> sentence: "
                + output[idx]
                + " </s> reference: "
                + ref[idx]
            )
        raise NotImplementedError(undefined_dimension)

    def fact_prompt(idx):
        # Prompt for factual consistency detection (single dimension).
        if dimension == "consistency":
            return (
                "question: Is this claim consistent with the document? </s> claim: "
                + output[idx]
                + " </s> document: "
                + src[idx]
            )
        raise NotImplementedError("No other dimensions for the factual consistency detection task.")

    if task == "summarization":
        build_prompt = summarization_prompt
    elif task == "dialogue":
        build_prompt = dialogue_prompt
    elif task == "data2text":
        build_prompt = data2text_prompt
    elif task == "fact":
        build_prompt = fact_prompt
    else:
        # New customized tasks would be registered here.
        build_prompt = None

    input_with_question = []
    for idx in range(len(output)):
        # The unsupported-task error is raised per sample, so nothing is
        # raised when there is no sample to convert.
        if build_prompt is None:
            raise NotImplementedError("Other tasks are not implemented, please customize specific tasks here.")
        input_with_question.append(build_prompt(idx))
    return input_with_question
def convert_data_to_unieval_format(output_list, src_list=None, ref_list=None):
    """
    Turn parallel lists into the list-of-dicts layout consumed by UniEval.

    output_list: a list of model outputs; one record is produced per entry
    src_list: optional source inputs (e.g. source documents for summarization,
        dialogue histories for dialogue), indexed in parallel with output_list
    ref_list: optional human-annotated references, indexed in parallel with
        output_list

    Each record holds "system_output", then "source" and "reference" when the
    corresponding list was given, and finally an empty "context".
    """

    def make_record(idx):
        # Keys are inserted in a fixed order so serialized output is stable.
        record = {"system_output": output_list[idx]}
        if src_list is not None:
            record["source"] = src_list[idx]
        if ref_list is not None:
            record["reference"] = ref_list[idx]
        record["context"] = ""
        return record

    return [make_record(idx) for idx in range(len(output_list))]
def calculate_average_score(scores): def calculate_average_score(scores):
""" """
Calculate average scores for different metrics Calculate average scores for different metrics
scores: a list of scores for different metrics for each answer scores: a list of scores for different metrics for each answer
""" """
metrics = {metric: 0 for metric in scores[0]} metrics = {metric: 0 for metric in scores[0]}
...@@ -226,9 +263,9 @@ def analyze_unieval_results(results_path: str, save_path: str) -> None: ...@@ -226,9 +263,9 @@ def analyze_unieval_results(results_path: str, save_path: str) -> None:
frame_all.to_csv(os.path.join(save_path, "unieval_statistics.csv")) frame_all.to_csv(os.path.join(save_path, "unieval_statistics.csv"))
for metric in tqdm.tqdm( for metric in tqdm.tqdm(
frame_per_metric.keys(), frame_per_metric.keys(),
desc=f"UniEval metrics: ", desc=f"UniEval metrics: ",
total=len(frame_per_metric.keys()), total=len(frame_per_metric.keys()),
): ):
data = pd.DataFrame(frame_per_metric[metric]) data = pd.DataFrame(frame_per_metric[metric])
......
import io import io
import json import json
import os import os
import re
import string import string
from typing import Dict from typing import Dict
...@@ -55,7 +54,7 @@ def jload(f, mode="r"): ...@@ -55,7 +54,7 @@ def jload(f, mode="r"):
def get_json_list(file_path): def get_json_list(file_path):
with open(file_path, 'r') as f: with open(file_path, "r") as f:
json_list = [] json_list = []
for line in f: for line in f:
json_list.append(json.loads(line)) json_list.append(json.loads(line))
...@@ -187,9 +186,9 @@ def analyze_automatic_results(results_path: str, save_path: str) -> None: ...@@ -187,9 +186,9 @@ def analyze_automatic_results(results_path: str, save_path: str) -> None:
frame_all.to_csv(os.path.join(save_path, "automatic_evaluation_statistics.csv")) frame_all.to_csv(os.path.join(save_path, "automatic_evaluation_statistics.csv"))
for metric in tqdm.tqdm( for metric in tqdm.tqdm(
frame_per_metric.keys(), frame_per_metric.keys(),
desc=f"automatic metrics: ", desc=f"automatic metrics: ",
total=len(frame_per_metric.keys()), total=len(frame_per_metric.keys()),
): ):
data = pd.DataFrame(frame_per_metric[metric]) data = pd.DataFrame(frame_per_metric[metric])
......
...@@ -3,7 +3,6 @@ import json ...@@ -3,7 +3,6 @@ import json
from typing import Dict, Sequence from typing import Dict, Sequence
import torch import torch
from datasets import load_dataset
from torch.utils.data import Dataset from torch.utils.data import Dataset
from tqdm import tqdm from tqdm import tqdm
from transformers import AutoTokenizer from transformers import AutoTokenizer
...@@ -20,7 +19,8 @@ def _tokenize_fn(strings: Sequence[str], tokenizer: AutoTokenizer, max_length: i ...@@ -20,7 +19,8 @@ def _tokenize_fn(strings: Sequence[str], tokenizer: AutoTokenizer, max_length: i
padding="longest", padding="longest",
max_length=max_length, max_length=max_length,
truncation=True, truncation=True,
) for text in strings )
for text in strings
] ]
input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list] input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
input_ids_lens = labels_lens = [ input_ids_lens = labels_lens = [
...@@ -48,18 +48,17 @@ def preprocess(sources: Sequence[str], targets: Sequence[str], tokenizer: AutoTo ...@@ -48,18 +48,17 @@ def preprocess(sources: Sequence[str], targets: Sequence[str], tokenizer: AutoTo
class EasySupervisedDataset(Dataset): class EasySupervisedDataset(Dataset):
def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length: int = 512) -> None: def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length: int = 512) -> None:
super(EasySupervisedDataset, self).__init__() super(EasySupervisedDataset, self).__init__()
with open(data_file, "r", encoding="UTF-8") as f: with open(data_file, "r", encoding="UTF-8") as f:
all_lines = f.readlines() all_lines = f.readlines()
#split to source and target ,source the characters before "回答:" including "回答:", target the characters after "回答:" # split to source and target ,source the characters before "回答:" including "回答:", target the characters after "回答:"
sources, targets = [], [] sources, targets = [], []
for line in all_lines: for line in all_lines:
if "回答:" in line: if "回答:" in line:
sep_index = line.index("回答:") sep_index = line.index("回答:")
sources.append(line[:sep_index + 3]) sources.append(line[: sep_index + 3])
targets.append(line[sep_index + 3:] + tokenizer.eos_token) targets.append(line[sep_index + 3 :] + tokenizer.eos_token)
else: else:
sources.append(line) sources.append(line)
targets.append("" + tokenizer.eos_token) targets.append("" + tokenizer.eos_token)
...@@ -83,15 +82,17 @@ class EasySupervisedDataset(Dataset): ...@@ -83,15 +82,17 @@ class EasySupervisedDataset(Dataset):
class EasyPromptsDataset(Dataset): class EasyPromptsDataset(Dataset):
def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length: int = 96) -> None: def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length: int = 96) -> None:
super(EasyPromptsDataset, self).__init__() super(EasyPromptsDataset, self).__init__()
with open(data_file, "r", encoding="UTF-8") as f: with open(data_file, "r", encoding="UTF-8") as f:
all_lines = f.readlines() all_lines = f.readlines()
all_lines = [line if "回答:" not in line else line[:line.index("回答:") + 3] for line in all_lines] all_lines = [line if "回答:" not in line else line[: line.index("回答:") + 3] for line in all_lines]
self.prompts = [ self.prompts = [
tokenizer(line, return_tensors='pt', max_length=max_length, padding='max_length', tokenizer(line, return_tensors="pt", max_length=max_length, padding="max_length", truncation=True)[
truncation=True)['input_ids'].to(torch.cuda.current_device()).squeeze(0) "input_ids"
]
.to(torch.cuda.current_device())
.squeeze(0)
for line in tqdm(all_lines) for line in tqdm(all_lines)
] ]
self.data_file = data_file self.data_file = data_file
...@@ -110,7 +111,6 @@ class EasyPromptsDataset(Dataset): ...@@ -110,7 +111,6 @@ class EasyPromptsDataset(Dataset):
class EasyRewardDataset(Dataset): class EasyRewardDataset(Dataset):
def __init__(self, train_file: str, tokenizer: AutoTokenizer, special_token=None, max_length=512) -> None: def __init__(self, train_file: str, tokenizer: AutoTokenizer, special_token=None, max_length=512) -> None:
super(EasyRewardDataset, self).__init__() super(EasyRewardDataset, self).__init__()
self.chosen = [] self.chosen = []
...@@ -120,44 +120,42 @@ class EasyRewardDataset(Dataset): ...@@ -120,44 +120,42 @@ class EasyRewardDataset(Dataset):
else: else:
self.end_token = special_token self.end_token = special_token
print(self.end_token) print(self.end_token)
#read all lines in the train_file to a list # read all lines in the train_file to a list
with open(train_file, "r", encoding="UTF-8") as f: with open(train_file, "r", encoding="UTF-8") as f:
all_lines = f.readlines() all_lines = f.readlines()
for line in tqdm(all_lines): for line in tqdm(all_lines):
data = json.loads(line) data = json.loads(line)
prompt = "提问:" + data['prompt'] + " 回答:" prompt = "提问:" + data["prompt"] + " 回答:"
chosen = prompt + data['chosen'] + self.end_token chosen = prompt + data["chosen"] + self.end_token
chosen_token = tokenizer(chosen, chosen_token = tokenizer(
max_length=max_length, chosen, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
padding="max_length", )
truncation=True, self.chosen.append(
return_tensors="pt") {"input_ids": chosen_token["input_ids"], "attention_mask": chosen_token["attention_mask"]}
self.chosen.append({ )
"input_ids": chosen_token['input_ids'],
"attention_mask": chosen_token['attention_mask'] reject = prompt + data["rejected"] + self.end_token
}) reject_token = tokenizer(
reject, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
reject = prompt + data['rejected'] + self.end_token )
reject_token = tokenizer(reject, self.reject.append(
max_length=max_length, {"input_ids": reject_token["input_ids"], "attention_mask": reject_token["attention_mask"]}
padding="max_length", )
truncation=True,
return_tensors="pt")
self.reject.append({
"input_ids": reject_token['input_ids'],
"attention_mask": reject_token['attention_mask']
})
def __len__(self): def __len__(self):
length = len(self.chosen) length = len(self.chosen)
return length return length
def __getitem__(self, idx): def __getitem__(self, idx):
return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][ return (
"input_ids"], self.reject[idx]["attention_mask"] self.chosen[idx]["input_ids"],
self.chosen[idx]["attention_mask"],
#python representation of the object and the string representation of the object self.reject[idx]["input_ids"],
self.reject[idx]["attention_mask"],
)
# python representation of the object and the string representation of the object
def __repr__(self): def __repr__(self):
return f"LawRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})" return f"LawRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})"
...@@ -165,26 +163,25 @@ class EasyRewardDataset(Dataset): ...@@ -165,26 +163,25 @@ class EasyRewardDataset(Dataset):
return f"LawRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})" return f"LawRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})"
"""
Easy SFT just accepts a text file which can be read line by line. However, the dataset will group texts together up to max_length so that the LLM can learn the meaning of the texts better.
If individual lines are not related, just set is_group_texts to False.
"""
class EasySFTDataset(Dataset): class EasySFTDataset(Dataset):
def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length=512, is_group_texts=True) -> None: def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length=512, is_group_texts=True) -> None:
super().__init__() super().__init__()
#read the data_file line by line # read the data_file line by line
with open(data_file, "r", encoding="UTF-8") as f: with open(data_file, "r", encoding="UTF-8") as f:
#encode the text data line by line and put raw python list input_ids only to raw_input_ids list # encode the text data line by line and put raw python list input_ids only to raw_input_ids list
raw_input_ids = [] raw_input_ids = []
for line in f: for line in f:
encoded_ids = tokenizer.encode(line) encoded_ids = tokenizer.encode(line)
#if the encoded_ids is longer than max_length, then split it into several parts # if the encoded_ids is longer than max_length, then split it into several parts
if len(encoded_ids) > max_length: if len(encoded_ids) > max_length:
for i in range(0, len(encoded_ids), max_length): for i in range(0, len(encoded_ids), max_length):
raw_input_ids.append(encoded_ids[i:i + max_length]) raw_input_ids.append(encoded_ids[i : i + max_length])
else: else:
raw_input_ids.append(encoded_ids) raw_input_ids.append(encoded_ids)
...@@ -196,12 +193,13 @@ class EasySFTDataset(Dataset): ...@@ -196,12 +193,13 @@ class EasySFTDataset(Dataset):
if is_group_texts: if is_group_texts:
for input_ids in raw_input_ids: for input_ids in raw_input_ids:
if len(current_input_ids) + len(input_ids) > max_length: if len(current_input_ids) + len(input_ids) > max_length:
#pad the current_input_ids to max_length with tokenizer.pad_token_id # pad the current_input_ids to max_length with tokenizer.pad_token_id
padded_length = max_length - len(current_input_ids) padded_length = max_length - len(current_input_ids)
current_input_ids.extend([tokenizer.pad_token_id] * padded_length) current_input_ids.extend([tokenizer.pad_token_id] * padded_length)
grouped_input_ids.append(torch.tensor(current_input_ids, dtype=torch.long)) grouped_input_ids.append(torch.tensor(current_input_ids, dtype=torch.long))
attention_mask.append( attention_mask.append(
torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)) torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)
)
current_input_ids = [] current_input_ids = []
else: else:
current_input_ids.extend(input_ids) current_input_ids.extend(input_ids)
...@@ -210,14 +208,16 @@ class EasySFTDataset(Dataset): ...@@ -210,14 +208,16 @@ class EasySFTDataset(Dataset):
current_input_ids.extend([tokenizer.pad_token_id] * padded_length) current_input_ids.extend([tokenizer.pad_token_id] * padded_length)
grouped_input_ids.append(torch.tensor(current_input_ids, dtype=torch.long)) grouped_input_ids.append(torch.tensor(current_input_ids, dtype=torch.long))
attention_mask.append( attention_mask.append(
torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)) torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)
)
else: else:
#just append the raw_input_ids to max_length # just append the raw_input_ids to max_length
for input_ids in raw_input_ids: for input_ids in raw_input_ids:
padded_length = max_length - len(input_ids) padded_length = max_length - len(input_ids)
input_ids.extend([tokenizer.pad_token_id] * padded_length) input_ids.extend([tokenizer.pad_token_id] * padded_length)
attention_mask.append( attention_mask.append(
torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)) torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)
)
grouped_input_ids.append(torch.tensor(input_ids, dtype=torch.long)) grouped_input_ids.append(torch.tensor(input_ids, dtype=torch.long))
self.input_ids = grouped_input_ids self.input_ids = grouped_input_ids
self.labels = copy.deepcopy(self.input_ids) self.labels = copy.deepcopy(self.input_ids)
...@@ -227,14 +227,14 @@ class EasySFTDataset(Dataset): ...@@ -227,14 +227,14 @@ class EasySFTDataset(Dataset):
def __len__(self): def __len__(self):
return len(self.input_ids) return len(self.input_ids)
#get item from dataset # get item from dataset
def __getitem__(self, idx): def __getitem__(self, idx):
return dict(input_ids=self.input_ids[idx], labels=self.labels[idx], attention_mask=self.attention_mask[idx]) return dict(input_ids=self.input_ids[idx], labels=self.labels[idx], attention_mask=self.attention_mask[idx])
#generate the dataset description to be printed by print in python # generate the dataset description to be printed by print in python
def __repr__(self): def __repr__(self):
return f"EasySFTDataset(len={len(self)},\nfile_name is {self.file_name})" return f"EasySFTDataset(len={len(self)},\nfile_name is {self.file_name})"
#generate the dataset description to be printed by print in python # generate the dataset description to be printed by print in python
def __str__(self): def __str__(self):
return f"EasySFTDataset(len={len(self)},\nfile_name is {self.file_name})" return f"EasySFTDataset(len={len(self)},\nfile_name is {self.file_name})"
...@@ -4,7 +4,7 @@ import torch ...@@ -4,7 +4,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from coati.models.generation import generate from coati.models.generation import generate
from coati.models.utils import log_probs_from_logits, masked_mean from coati.models.utils import log_probs_from_logits
from peft import PeftModel from peft import PeftModel
from torch.nn.modules import Module from torch.nn.modules import Module
from transformers import BloomConfig, BloomForCausalLM from transformers import BloomConfig, BloomForCausalLM
...@@ -24,38 +24,33 @@ class Actor(Module): ...@@ -24,38 +24,33 @@ class Actor(Module):
@torch.no_grad() @torch.no_grad()
def generate( def generate(
self, self, input_ids: torch.Tensor, return_action_mask: bool = True, **kwargs
input_ids: torch.Tensor,
return_action_mask: bool = True,
**kwargs
) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]: ) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]:
sequences = generate(self.model, input_ids, **kwargs) sequences = generate(self.model, input_ids, **kwargs)
attention_mask = None attention_mask = None
pad_token_id = kwargs.get('pad_token_id', None) pad_token_id = kwargs.get("pad_token_id", None)
if pad_token_id is not None: if pad_token_id is not None:
attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device) attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
if not return_action_mask: if not return_action_mask:
return sequences, attention_mask, None return sequences, attention_mask, None
input_len = input_ids.size(1) input_len = input_ids.size(1)
eos_token_id = kwargs.get('eos_token_id', None) eos_token_id = kwargs.get("eos_token_id", None)
if eos_token_id is None: if eos_token_id is None:
action_mask = torch.ones_like(sequences, dtype=torch.bool) action_mask = torch.ones_like(sequences, dtype=torch.bool)
else: else:
# left padding may be applied, only mask action # left padding may be applied, only mask action
action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0 action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input
action_mask[:, :input_len] = False action_mask[:, :input_len] = False
action_mask = action_mask[:, 1:] action_mask = action_mask[:, 1:]
return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len):] return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len) :]
def forward(self, def forward(
sequences: torch.LongTensor, self, sequences: torch.LongTensor, num_actions: int, attention_mask: Optional[torch.Tensor] = None
num_actions: int, ) -> torch.Tensor:
attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: """Returns action log probs"""
"""Returns action log probs
"""
output = self.model(sequences, attention_mask=attention_mask) output = self.model(sequences, attention_mask=attention_mask)
logits = output['logits'] logits = output["logits"]
log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:]) log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
return log_probs[:, -num_actions:] return log_probs[:, -num_actions:]
...@@ -75,11 +70,13 @@ class BLOOMActor(Actor): ...@@ -75,11 +70,13 @@ class BLOOMActor(Actor):
lora_train_bias (str): LoRA bias training mode. lora_train_bias (str): LoRA bias training mode.
""" """
def __init__(self, def __init__(
pretrained: str = None, self,
config: Optional[BloomConfig] = None, pretrained: str = None,
checkpoint: bool = False, config: Optional[BloomConfig] = None,
lora_path: str = None) -> None: checkpoint: bool = False,
lora_path: str = None,
) -> None:
if pretrained is not None: if pretrained is not None:
model = BloomForCausalLM.from_pretrained(pretrained) model = BloomForCausalLM.from_pretrained(pretrained)
elif config is not None: elif config is not None:
......
import argparse import argparse
import pandas as pd
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from coati.dataset import DataCollatorForSupervisedDataset, PromptDataset, SupervisedDataset from coati.dataset import DataCollatorForSupervisedDataset
from coati.models.bloom import BLOOMRM, BLOOMCritic from coati.models.bloom import BLOOMRM, BLOOMCritic
from coati.models.gpt import GPTRM, GPTActor, GPTCritic from coati.models.gpt import GPTRM, GPTCritic
from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM from coati.models.llama import LlamaCritic, LlamaRM
from coati.models.opt import OPTRM, OPTActor, OPTCritic from coati.models.opt import OPTRM, OPTCritic
from coati.trainer import PPOTrainer from coati.trainer import PPOTrainer
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
from easy_dataset import EasyPromptsDataset, EasySupervisedDataset from easy_dataset import EasyPromptsDataset, EasySupervisedDataset
from easy_models import BLOOMActor from easy_models import BLOOMActor
from peft import PeftModel
from torch.optim import Adam from torch.optim import Adam
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
...@@ -23,24 +21,24 @@ from colossalai.nn.optimizer import HybridAdam ...@@ -23,24 +21,24 @@ from colossalai.nn.optimizer import HybridAdam
def main(args): def main(args):
# configure strategy # configure strategy
if args.strategy == 'ddp': if args.strategy == "ddp":
strategy = DDPStrategy() strategy = DDPStrategy()
elif args.strategy == 'colossalai_gemini': elif args.strategy == "colossalai_gemini":
strategy = GeminiStrategy(placement_policy='cpu', initial_scale=2**5) strategy = GeminiStrategy(placement_policy="cpu", initial_scale=2**5)
elif args.strategy == 'colossalai_zero2': elif args.strategy == "colossalai_zero2":
strategy = LowLevelZeroStrategy(stage=2, placement_policy='cpu') strategy = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
else: else:
raise ValueError(f'Unsupported strategy "{args.strategy}"') raise ValueError(f'Unsupported strategy "{args.strategy}"')
if args.rm_path is not None: if args.rm_path is not None:
state_dict = torch.load(args.rm_path, map_location='cpu') state_dict = torch.load(args.rm_path, map_location="cpu")
# configure model # configure model
if args.model == 'bloom': if args.model == "bloom":
# initial_model = BLOOMActor(pretrained=args.pretrain) # initial_model = BLOOMActor(pretrained=args.pretrain)
print('Using peft lora to load Bloom model as initial_model') print("Using peft lora to load Bloom model as initial_model")
initial_model = BLOOMActor(pretrained=args.pretrain, lora_path=args.sft_lora_path) initial_model = BLOOMActor(pretrained=args.pretrain, lora_path=args.sft_lora_path)
print('Using peft lora to load Bloom model as initial_model (Done)') print("Using peft lora to load Bloom model as initial_model (Done)")
else: else:
raise ValueError(f'Unsupported actor model "{args.model}"') raise ValueError(f'Unsupported actor model "{args.model}"')
...@@ -49,59 +47,59 @@ def main(args): ...@@ -49,59 +47,59 @@ def main(args):
else: else:
rm_model_name = args.rm_model rm_model_name = args.rm_model
if rm_model_name == 'gpt2': if rm_model_name == "gpt2":
reward_model = GPTRM(pretrained=args.rm_pretrain) reward_model = GPTRM(pretrained=args.rm_pretrain)
elif rm_model_name == 'bloom': elif rm_model_name == "bloom":
print("load bloom reward model ", args.rm_pretrain) print("load bloom reward model ", args.rm_pretrain)
reward_model = BLOOMRM(pretrained=args.rm_pretrain) reward_model = BLOOMRM(pretrained=args.rm_pretrain)
elif rm_model_name == 'opt': elif rm_model_name == "opt":
reward_model = OPTRM(pretrained=args.rm_pretrain) reward_model = OPTRM(pretrained=args.rm_pretrain)
elif rm_model_name == 'llama': elif rm_model_name == "llama":
reward_model = LlamaRM(pretrained=args.rm_pretrain) reward_model = LlamaRM(pretrained=args.rm_pretrain)
else: else:
raise ValueError(f'Unsupported reward model "{rm_model_name}"') raise ValueError(f'Unsupported reward model "{rm_model_name}"')
if args.rm_path is not None: if args.rm_path is not None:
print('Loading reward model from', args.rm_path) print("Loading reward model from", args.rm_path)
reward_model.load_state_dict(state_dict) reward_model.load_state_dict(state_dict)
if args.strategy != 'colossalai_gemini': if args.strategy != "colossalai_gemini":
initial_model.to(torch.float16).to(torch.cuda.current_device()) initial_model.to(torch.float16).to(torch.cuda.current_device())
reward_model.to(torch.float16).to(torch.cuda.current_device()) reward_model.to(torch.float16).to(torch.cuda.current_device())
with strategy.model_init_context(): with strategy.model_init_context():
if args.model == 'bloom': if args.model == "bloom":
# actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank) # actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
print('Using peft lora to load Bloom model as Actor') print("Using peft lora to load Bloom model as Actor")
actor = BLOOMActor(pretrained=args.pretrain, lora_path=args.sft_lora_path) actor = BLOOMActor(pretrained=args.pretrain, lora_path=args.sft_lora_path)
print('Using peft lora to load Bloom model as Actor (Done)') print("Using peft lora to load Bloom model as Actor (Done)")
else: else:
raise ValueError(f'Unsupported actor model "{args.model}"') raise ValueError(f'Unsupported actor model "{args.model}"')
if rm_model_name == 'gpt2': if rm_model_name == "gpt2":
critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
elif rm_model_name == 'bloom': elif rm_model_name == "bloom":
print("load bloom critic ", args.rm_pretrain, " lora_rank ", args.lora_rank, " use_action_mask ", True) print("load bloom critic ", args.rm_pretrain, " lora_rank ", args.lora_rank, " use_action_mask ", True)
critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
print("load bloom critic (Done) ") print("load bloom critic (Done) ")
elif rm_model_name == 'opt': elif rm_model_name == "opt":
critic = OPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) critic = OPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
elif rm_model_name == 'llama': elif rm_model_name == "llama":
critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
else: else:
raise ValueError(f'Unsupported reward model "{rm_model_name}"') raise ValueError(f'Unsupported reward model "{rm_model_name}"')
if args.rm_path is not None: if args.rm_path is not None:
print('Loading reward model from', args.rm_path) print("Loading reward model from", args.rm_path)
critic.load_state_dict(state_dict) critic.load_state_dict(state_dict)
del state_dict del state_dict
if args.strategy != 'colossalai_gemini': if args.strategy != "colossalai_gemini":
critic.to(torch.float16).to(torch.cuda.current_device()) critic.to(torch.float16).to(torch.cuda.current_device())
actor.to(torch.float16).to(torch.cuda.current_device()) actor.to(torch.float16).to(torch.cuda.current_device())
# configure optimizer # configure optimizer
if args.strategy.startswith('colossalai'): if args.strategy.startswith("colossalai"):
actor_optim = HybridAdam(actor.parameters(), lr=1e-7) actor_optim = HybridAdam(actor.parameters(), lr=1e-7)
critic_optim = HybridAdam(critic.parameters(), lr=1e-7) critic_optim = HybridAdam(critic.parameters(), lr=1e-7)
else: else:
...@@ -109,18 +107,18 @@ def main(args): ...@@ -109,18 +107,18 @@ def main(args):
critic_optim = Adam(critic.parameters(), lr=1e-7) critic_optim = Adam(critic.parameters(), lr=1e-7)
# configure tokenizer # configure tokenizer
if args.model == 'gpt2': if args.model == "gpt2":
tokenizer = GPT2Tokenizer.from_pretrained(args.rm_pretrain) tokenizer = GPT2Tokenizer.from_pretrained(args.rm_pretrain)
tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'bloom': elif args.model == "bloom":
tokenizer = BloomTokenizerFast.from_pretrained(args.rm_pretrain) tokenizer = BloomTokenizerFast.from_pretrained(args.rm_pretrain)
tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'opt': elif args.model == "opt":
tokenizer = AutoTokenizer.from_pretrained(args.rm_pretrain) tokenizer = AutoTokenizer.from_pretrained(args.rm_pretrain)
tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'llama': elif args.model == "llama":
tokenizer = LlamaTokenizer.from_pretrained(args.pretrain) tokenizer = LlamaTokenizer.from_pretrained(args.pretrain)
tokenizer.eos_token = '<\s>' tokenizer.eos_token = "<\s>"
tokenizer.pad_token = tokenizer.unk_token tokenizer.pad_token = tokenizer.unk_token
else: else:
raise ValueError(f'Unsupported model "{args.model}"') raise ValueError(f'Unsupported model "{args.model}"')
...@@ -132,26 +130,27 @@ def main(args): ...@@ -132,26 +130,27 @@ def main(args):
prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True) prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True)
else: else:
prompt_sampler = None prompt_sampler = None
prompt_dataloader = DataLoader(prompt_dataset, prompt_dataloader = DataLoader(
shuffle=(prompt_sampler is None), prompt_dataset, shuffle=(prompt_sampler is None), sampler=prompt_sampler, batch_size=args.train_batch_size
sampler=prompt_sampler, )
batch_size=args.train_batch_size)
pretrain_dataset = EasySupervisedDataset(args.pretrain_dataset, tokenizer) pretrain_dataset = EasySupervisedDataset(args.pretrain_dataset, tokenizer)
if dist.is_initialized() and dist.get_world_size() > 1: if dist.is_initialized() and dist.get_world_size() > 1:
pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True) pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True)
else: else:
pretrain_sampler = None pretrain_sampler = None
pretrain_dataloader = DataLoader(pretrain_dataset, pretrain_dataloader = DataLoader(
shuffle=(pretrain_sampler is None), pretrain_dataset,
sampler=pretrain_sampler, shuffle=(pretrain_sampler is None),
batch_size=args.ptx_batch_size, sampler=pretrain_sampler,
collate_fn=data_collator) batch_size=args.ptx_batch_size,
collate_fn=data_collator,
)
def tokenize_fn(texts): def tokenize_fn(texts):
# MUST padding to max length to ensure inputs of all ranks have the same length # MUST padding to max length to ensure inputs of all ranks have the same length
# Different length may lead to hang when using gemini, as different generation steps # Different length may lead to hang when using gemini, as different generation steps
batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True) batch = tokenizer(texts, return_tensors="pt", max_length=96, padding="max_length", truncation=True)
return {k: v.to(torch.cuda.current_device()) for k, v in batch.items()} return {k: v.to(torch.cuda.current_device()) for k, v in batch.items()}
(actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim)) (actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim))
...@@ -178,45 +177,46 @@ def main(args): ...@@ -178,45 +177,46 @@ def main(args):
eos_token_id=tokenizer.eos_token_id, eos_token_id=tokenizer.eos_token_id,
) )
trainer.fit(prompt_dataloader=prompt_dataloader, trainer.fit(
pretrain_dataloader=pretrain_dataloader, prompt_dataloader=prompt_dataloader,
num_episodes=args.num_episodes, pretrain_dataloader=pretrain_dataloader,
num_update_steps=args.num_update_steps, num_episodes=args.num_episodes,
num_collect_steps=args.num_collect_steps) num_update_steps=args.num_update_steps,
num_collect_steps=args.num_collect_steps,
)
# save model checkpoint after fitting # save model checkpoint after fitting
trainer.save_model(args.save_path, only_rank0=True, tokenizer=tokenizer) trainer.save_model(args.save_path, only_rank0=True, tokenizer=tokenizer)
# save optimizer checkpoint on all ranks # save optimizer checkpoint on all ranks
if args.need_optim_ckpt: if args.need_optim_ckpt:
strategy.save_optimizer(actor_optim, strategy.save_optimizer(
'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()), actor_optim, "actor_optim_checkpoint_prompts_%d.pt" % (torch.cuda.current_device()), only_rank0=False
only_rank0=False) )
if __name__ == '__main__': if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--prompt_path', type=str, default=None, help='path to the prompt dataset') parser.add_argument("--prompt_path", type=str, default=None, help="path to the prompt dataset")
parser.add_argument('--pretrain_dataset', type=str, default=None, help='path to the pretrained dataset') parser.add_argument("--pretrain_dataset", type=str, default=None, help="path to the pretrained dataset")
parser.add_argument('--strategy', parser.add_argument(
choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'], "--strategy", choices=["ddp", "colossalai_gemini", "colossalai_zero2"], default="ddp", help="strategy to use"
default='ddp', )
help='strategy to use') parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama']) parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument('--pretrain', type=str, default=None) parser.add_argument("--sft_lora_path", type=str, default=None)
parser.add_argument('--sft_lora_path', type=str, default=None) parser.add_argument("--rm_model", default=None, choices=["gpt2", "bloom", "opt", "llama"])
parser.add_argument('--rm_model', default=None, choices=['gpt2', 'bloom', 'opt', 'llama']) parser.add_argument("--rm_path", type=str, default=None)
parser.add_argument('--rm_path', type=str, default=None) parser.add_argument("--rm_pretrain", type=str, default=None)
parser.add_argument('--rm_pretrain', type=str, default=None) parser.add_argument("--save_path", type=str, default="actor_checkpoint_prompts")
parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts') parser.add_argument("--need_optim_ckpt", type=bool, default=False)
parser.add_argument('--need_optim_ckpt', type=bool, default=False) parser.add_argument("--num_episodes", type=int, default=10)
parser.add_argument('--num_episodes', type=int, default=10) parser.add_argument("--num_collect_steps", type=int, default=10)
parser.add_argument('--num_collect_steps', type=int, default=10) parser.add_argument("--num_update_steps", type=int, default=5)
parser.add_argument('--num_update_steps', type=int, default=5) parser.add_argument("--train_batch_size", type=int, default=2)
parser.add_argument('--train_batch_size', type=int, default=2) parser.add_argument("--ptx_batch_size", type=int, default=1)
parser.add_argument('--ptx_batch_size', type=int, default=1) parser.add_argument("--experience_batch_size", type=int, default=8)
parser.add_argument('--experience_batch_size', type=int, default=8) parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank") parser.add_argument("--kl_coef", type=float, default=0.1)
parser.add_argument('--kl_coef', type=float, default=0.1) parser.add_argument("--ptx_coef", type=float, default=0.9)
parser.add_argument('--ptx_coef', type=float, default=0.9)
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment