Commit 2369eb2b authored by chenych

update

parent ac9d2b05
@@ -41,14 +41,15 @@ EasyR1 is built on **[HybridEngine](https://arxiv.org/abs/2409.19256)** and the latest
> [!NOTE]
> Use the `worker.actor.fsdp.torch_dtype=bf16` and `worker.actor.optim.strategy=adamw_bf16` parameters to ensure training runs in bf16.
>
-> We are working on reducing VRAM usage in RL training; LoRA support will be integrated in the next update.
+> Training relies on wandb; once the environment is installed, log in to wandb before starting.
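A minimal sketch of the wandb login step mentioned in the note above, using the Python API (assumes the `wandb` package is installed; this is an illustration, not part of the commit):

```python
import wandb

# Reads WANDB_API_KEY from the environment if set, otherwise prompts for an
# API key, and stores the credentials so later training runs can log to wandb.
wandb.login()
```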
## Tutorial: Train Qwen2.5-VL on the [Geometry3K](https://huggingface.co/datasets/hiyouga/geometry3k) dataset with GRPO in just three steps.
![image](assets/qwen2_5_vl_7b_geo.png)
### Environment Setup
Adjust the `-v` mount path, `docker_name`, and `imageID` below to your actual setup.
#### Docker (Option 1)
......
assets/wechat.jpg: image replaced (164 KB → 111 KB)
@@ -2,18 +2,13 @@ set -x
MODEL_PATH=Qwen/Qwen2.5-VL-3B-Instruct # replace it with your local file path
-FORMAT_PROMPT="""A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant
-first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning
-process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e.,
-<think> reasoning process here </think><answer> answer here </answer>"""
python3 -m verl.trainer.main \
    config=examples/config.yaml \
    data.train_files=BUAADreamer/clevr_count_70k@train \
    data.val_files=BUAADreamer/clevr_count_70k@test \
-    data.format_prompt="${FORMAT_PROMPT}" \
+    data.format_prompt=./examples/format_prompt/r1v_format.jinja \
    worker.actor.model.model_path=${MODEL_PATH} \
    worker.rollout.tensor_parallel_size=1 \
-    worker.reward.score_function=r1v \
+    worker.reward.score_function=./examples/score_function/r1v.py:compute_score \
    trainer.experiment_name=qwen2_5_vl_3b_clevr \
    trainer.n_gpus_per_node=2
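The updated scripts point `data.format_prompt` at a Jinja template file and `worker.reward.score_function` at a `path/to/file.py:function_name` spec. A minimal sketch of how such a spec could be resolved to a callable (`load_score_function` is a hypothetical helper for illustration, not necessarily EasyR1's actual loader):

```python
import importlib.util
import os


def load_score_function(spec: str):
    """Resolve a "path/to/file.py:function_name" spec to a callable."""
    path, _, func_name = spec.partition(":")
    module_name = os.path.splitext(os.path.basename(path))[0]
    module_spec = importlib.util.spec_from_file_location(module_name, path)
    module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(module)  # execute the file as a module
    return getattr(module, func_name)


# Example: resolves to compute_score defined in examples/score_function/r1v.py
score_fn = load_score_function("./examples/score_function/r1v.py:compute_score")
```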
@@ -2,18 +2,13 @@ set -x
MODEL_PATH=Qwen/Qwen2.5-VL-3B-Instruct # replace it with your local file path
-FORMAT_PROMPT="""A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant
-first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning
-process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e.,
-<think> reasoning process here </think><answer> answer here </answer>"""
python3 -m verl.trainer.main \
    config=examples/config.yaml \
    data.train_files=leonardPKU/GEOQA_8K_R1V@train \
    data.val_files=leonardPKU/GEOQA_8K_R1V@test \
-    data.format_prompt="${FORMAT_PROMPT}" \
+    data.format_prompt=./examples/format_prompt/r1v_format.jinja \
    worker.actor.model.model_path=${MODEL_PATH} \
    worker.rollout.tensor_parallel_size=1 \
-    worker.reward.score_function=r1v \
+    worker.reward.score_function=./examples/score_function/r1v.py:compute_score \
    trainer.experiment_name=qwen2_5_vl_3b_geoqa8k \
    trainer.n_gpus_per_node=8
@@ -8,10 +8,12 @@ data:
  max_response_length: 2048
  rollout_batch_size: 512
  val_batch_size: -1
+  format_prompt: ./examples/format_prompt/math_format.jinja
  shuffle: true
  seed: 1
  max_pixels: 4194304
  min_pixels: 262144
+  filter_overlong_prompts: true
algorithm:
  adv_estimator: grpo
@@ -47,8 +49,9 @@ worker:
    offload_optimizer: true # true: more CPU memory; false: more GPU memory
  rollout:
-    temperature: 1.0
    n: 5
+    temperature: 1.0
+    top_p: 0.99
    gpu_memory_utilization: 0.6
    enforce_eager: false
    enable_chunked_prefill: false
@@ -68,8 +71,7 @@ worker:
  reward:
    reward_type: function
-    score_function: math
+    score_function: ./examples/score_function/math.py:compute_score
-    skip_special_tokens: true
trainer:
  total_episodes: 15
......
{{ content | trim }} You FIRST think about the reasoning process as an internal monologue and then provide the final answer. The reasoning process MUST BE enclosed within <think> </think> tags. The final answer MUST BE put in \boxed{}.
{{ content | trim }} A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>
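The two templates above append a fixed reasoning instruction to the raw problem text via the `content` variable. A minimal rendering sketch with jinja2, assuming the second template is saved at the path referenced by the scripts above (the sample question is made up):

```python
from jinja2 import Template

# Render the r1v-style format prompt for a sample question.
with open("examples/format_prompt/r1v_format.jinja") as f:
    template = Template(f.read())

prompt = template.render(content="How many red cubes are in the image?")
print(prompt)
```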
@@ -2,14 +2,10 @@ set -x
MODEL_PATH=Qwen/Qwen2.5-7B-Instruct # replace it with your local file path
-FORMAT_PROMPT="""You FIRST think about the reasoning process as an internal monologue and then provide the final answer.
-The reasoning process MUST BE enclosed within <think> </think> tags. The final answer MUST BE put in \boxed{}."""
python3 -m verl.trainer.main \
    config=examples/config.yaml \
    data.train_files=hiyouga/math12k@train \
    data.val_files=hiyouga/math12k@test \
-    data.format_prompt="${FORMAT_PROMPT}" \
    worker.actor.model.model_path=${MODEL_PATH} \
    trainer.experiment_name=qwen2_5_7b_math_grpo \
    trainer.n_gpus_per_node=8
@@ -2,14 +2,10 @@ set -x
MODEL_PATH=Qwen/Qwen2.5-VL-32B-Instruct # replace it with your local file path
-FORMAT_PROMPT="""You FIRST think about the reasoning process as an internal monologue and then provide the final answer.
-The reasoning process MUST BE enclosed within <think> </think> tags. The final answer MUST BE put in \boxed{}."""
python3 -m verl.trainer.main \
    config=examples/config.yaml \
    data.train_files=hiyouga/geometry3k@train \
    data.val_files=hiyouga/geometry3k@test \
-    data.format_prompt="${FORMAT_PROMPT}" \
    worker.actor.model.model_path=${MODEL_PATH} \
    worker.actor.micro_batch_size_per_device_for_update=1 \
    worker.actor.micro_batch_size_per_device_for_experience=8 \
......
@@ -2,14 +2,10 @@ set -x
MODEL_PATH=Qwen/Qwen2.5-VL-3B-Instruct # replace it with your local file path
-FORMAT_PROMPT="""You FIRST think about the reasoning process as an internal monologue and then provide the final answer.
-The reasoning process MUST BE enclosed within <think> </think> tags. The final answer MUST BE put in \boxed{}."""
python3 -m verl.trainer.main \
    config=examples/config.yaml \
    data.train_files=hiyouga/geometry3k@train \
    data.val_files=hiyouga/geometry3k@test \
-    data.format_prompt="${FORMAT_PROMPT}" \
    worker.actor.model.model_path=${MODEL_PATH} \
    worker.rollout.tensor_parallel_size=1 \
    trainer.experiment_name=qwen2_5_vl_3b_geo_grpo \
......
@@ -2,14 +2,10 @@ set -x
MODEL_PATH=Qwen/Qwen2.5-VL-7B-Instruct # replace it with your local file path
-FORMAT_PROMPT="""You FIRST think about the reasoning process as an internal monologue and then provide the final answer.
-The reasoning process MUST BE enclosed within <think> </think> tags. The final answer MUST BE put in \boxed{}."""
python3 -m verl.trainer.main \
    config=examples/config.yaml \
    data.train_files=hiyouga/geometry3k@train \
    data.val_files=hiyouga/geometry3k@test \
-    data.format_prompt="${FORMAT_PROMPT}" \
    worker.actor.model.model_path=${MODEL_PATH} \
    trainer.experiment_name=qwen2_5_vl_7b_geo_grpo \
    trainer.n_gpus_per_node=8
@@ -2,14 +2,10 @@ set -x
MODEL_PATH=Qwen/Qwen2.5-VL-7B-Instruct # replace it with your local file path
-FORMAT_PROMPT="""You FIRST think about the reasoning process as an internal monologue and then provide the final answer.
-The reasoning process MUST BE enclosed within <think> </think> tags. The final answer MUST BE put in \boxed{}."""
python3 -m verl.trainer.main \
    config=examples/config.yaml \
    data.train_files=hiyouga/geometry3k@train \
    data.val_files=hiyouga/geometry3k@test \
-    data.format_prompt="${FORMAT_PROMPT}" \
    worker.actor.model.model_path=${MODEL_PATH} \
    algorithm.adv_estimator=reinforce_plus_plus \
    algorithm.use_kl_loss=false \
......
@@ -2,14 +2,10 @@ set -x
MODEL_PATH=Qwen/Qwen2.5-VL-7B-Instruct # replace it with your local file path
-FORMAT_PROMPT="""You FIRST think about the reasoning process as an internal monologue and then provide the final answer.
-The reasoning process MUST BE enclosed within <think> </think> tags. The final answer MUST BE put in \boxed{}."""
python3 -m verl.trainer.main \
    config=examples/config.yaml \
    data.train_files=hiyouga/geometry3k@train \
    data.val_files=hiyouga/geometry3k@test \
-    data.format_prompt="${FORMAT_PROMPT}" \
    worker.actor.model.model_path=${MODEL_PATH} \
    trainer.experiment_name=qwen2_5_vl_7b_geo_grpo \
    trainer.logger=['console','swanlab'] \
......
# REMINDER: this script uses test data split and should ONLY be used for debugging. DO NOT use for training.
set -x
MODEL_PATH=Qwen/Qwen2.5-VL-7B-Instruct # replace it with your local file path
python3 -m verl.trainer.main \
    config=examples/config.yaml \
    data.train_files=hiyouga/journeybench-multi-image-vqa@train \
    data.val_files=hiyouga/journeybench-multi-image-vqa@test \
    data.rollout_batch_size=256 \
    worker.actor.model.model_path=${MODEL_PATH} \
    worker.rollout.limit_images=2 \
    trainer.experiment_name=qwen2_5_vl_7b_multi_image \
    trainer.n_gpus_per_node=8
working_dir: ./
excludes: ["/.git/"]
env_vars:
  TOKENIZERS_PARALLELISM: "true"
  NCCL_DEBUG: "WARN"
  VLLM_LOGGING_LEVEL: "INFO"
  TORCH_NCCL_AVOID_RECORD_STREAMS: "1"
  PYTORCH_CUDA_ALLOC_CONF: "expandable_segments:False"
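The file above is a Ray runtime environment spec (working directory, excludes, environment variables). A minimal sketch of passing it to Ray, assuming it is saved as `runtime_env.yaml`; the exact way EasyR1 consumes it may differ:

```python
import ray
import yaml

# Load the runtime environment spec and start Ray with it, so every worker
# inherits the working_dir, excludes, and environment variables defined above.
with open("runtime_env.yaml") as f:
    runtime_env = yaml.safe_load(f)

ray.init(runtime_env=runtime_env)
```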
@@ -18,23 +18,23 @@ from typing import Dict
from mathruler.grader import extract_boxed_content, grade_answer
-def math_format_reward(predict_str: str) -> float:
+def format_reward(predict_str: str) -> float:
    pattern = re.compile(r"<think>.*</think>.*\\boxed\{.*\}.*", re.DOTALL)
    format_match = re.fullmatch(pattern, predict_str)
    return 1.0 if format_match else 0.0
-def math_acc_reward(predict_str: str, ground_truth: str) -> float:
+def accuracy_reward(predict_str: str, ground_truth: str) -> float:
    answer = extract_boxed_content(predict_str)
    return 1.0 if grade_answer(answer, ground_truth) else 0.0
-def math_compute_score(predict_str: str, ground_truth: str) -> Dict[str, float]:
+def compute_score(predict_str: str, ground_truth: str, format_weight: float = 0.1) -> Dict[str, float]:
    predict_str = re.sub(r"\s*(<|>|/)\s*", r"\1", predict_str)  # handle qwen2.5vl-32b format
-    format = math_format_reward(predict_str)
-    accuracy = math_acc_reward(predict_str, ground_truth)
+    format_score = format_reward(predict_str)
+    accuracy_score = accuracy_reward(predict_str, ground_truth)
    return {
-        "overall": 0.9 * accuracy + 0.1 * format,
-        "format": format,
-        "accuracy": accuracy,
+        "overall": (1 - format_weight) * accuracy_score + format_weight * format_score,
+        "format": format_score,
+        "accuracy": accuracy_score,
    }
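A quick usage sketch of the renamed `compute_score` above (assumes the function is in scope and the `mathruler` package is installed; the sample prediction is made up):

```python
# With the default format_weight=0.1, a prediction that both matches the
# <think>...</think> ... \boxed{} format and gives the right answer scores 1.0.
sample = r"<think>2 + 2 = 4</think> The answer is \boxed{4}."
print(compute_score(sample, ground_truth="4"))
# expected: overall 1.0, format 1.0, accuracy 1.0
```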
@@ -18,18 +18,17 @@ from typing import Dict
from mathruler.grader import grade_answer
-def r1v_format_reward(predict_str: str) -> float:
+def format_reward(predict_str: str) -> float:
    pattern = re.compile(r"<think>.*?</think>\s*<answer>.*?</answer>", re.DOTALL)
    format_match = re.fullmatch(pattern, predict_str)
    return 1.0 if format_match else 0.0
-def r1v_accuracy_reward(predict_str: str, ground_truth: str) -> float:
+def accuracy_reward(predict_str: str, ground_truth: str) -> float:
    try:
-        ground_truth = ground_truth.strip()
        content_match = re.search(r"<answer>(.*?)</answer>", predict_str)
        given_answer = content_match.group(1).strip() if content_match else predict_str.strip()
-        if grade_answer(given_answer, ground_truth):
+        if grade_answer(given_answer, ground_truth.strip()):
            return 1.0
    except Exception:
@@ -38,11 +37,11 @@ def r1v_accuracy_reward(predict_str: str, ground_truth: str) -> float:
    return 0.0
-def r1v_compute_score(predict_str: str, ground_truth: str) -> Dict[str, float]:
+def compute_score(predict_str: str, ground_truth: str, format_weight: float = 0.5) -> Dict[str, float]:
-    format = r1v_format_reward(predict_str)
-    accuracy = r1v_accuracy_reward(predict_str, ground_truth)
+    format_score = format_reward(predict_str)
+    accuracy_score = accuracy_reward(predict_str, ground_truth)
    return {
-        "overall": 0.5 * accuracy + 0.5 * format,
-        "format": format,
-        "accuracy": accuracy,
+        "overall": (1 - format_weight) * accuracy_score + format_weight * format_score,
+        "format": format_score,
+        "accuracy": accuracy_score,
    }
@@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-__version__ = "0.2.0.dev"
+__version__ = "0.3.0"
@@ -136,8 +136,8 @@ def fold_batch_dim(data: "DataProto", new_batch_size: int):
    tensor = tensor.view(new_batch_size, -1)
    tensor.auto_batch_size_(batch_dims=1)
-    for key, val in non_tensor.items():
-        non_tensor[key] = np.reshape(val, newshape=(new_batch_size, -1, *val.shape[1:]))
+    for key, value in non_tensor.items():
+        non_tensor[key] = np.reshape(value, newshape=(new_batch_size, -1, *value.shape[1:]))
    return DataProto(batch=tensor, non_tensor_batch=non_tensor, meta_info=data.meta_info)
@@ -182,14 +182,14 @@ class DataProto:
        if self.batch is not None:
            return self.batch.batch_size[0]
        elif self.non_tensor_batch is not None and len(self.non_tensor_batch) > 0:
-            random_key = list(self.non_tensor_batch.keys())[0]
-            return self.non_tensor_batch[random_key].shape[0]
+            pivot_key = list(self.non_tensor_batch.keys())[0]
+            return self.non_tensor_batch[pivot_key].shape[0]
        else:
            return 0
    def __getitem__(self, item: Union[int, slice]) -> Union["DataProto", "DataProtoItem"]:
        tensor_data = self.batch[item]
-        non_tensor_data = {key: val[item] for key, val in self.non_tensor_batch.items()}
+        non_tensor_data = {key: value[item] for key, value in self.non_tensor_batch.items()}
        return_type = DataProto if isinstance(item, slice) else DataProtoItem
        return return_type(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info)
@@ -223,9 +223,10 @@ class DataProto:
    def print_size(self, prefix: str = "") -> None:
        size_of_tensordict = 0
-        for tensor in self.batch.values():
-            if isinstance(tensor, torch.Tensor):
-                size_of_tensordict += tensor.element_size() * tensor.numel()
+        if self.batch is not None:
+            for tensor in self.batch.values():
+                if isinstance(tensor, torch.Tensor):
+                    size_of_tensordict += tensor.element_size() * tensor.numel()
        size_of_numpy_array = 0
        for value in self.non_tensor_batch.values():
@@ -249,8 +250,8 @@ class DataProto:
            assert len(self.batch.batch_size) == 1, "only support num_batch_dims=1 when non_tensor_batch is not empty."
            batch_size = self.batch.batch_size[0]
-            for key, val in self.non_tensor_batch.items():
-                assert len(val) == batch_size, f"key {key} length {len(val)} is not equal to batch size {batch_size}."
+            for key, value in self.non_tensor_batch.items():
+                assert len(value) == batch_size, f"key {key} length {len(value)} is not equal to bsz {batch_size}."
    @classmethod
    def from_single_dict(
@@ -258,8 +259,7 @@ class DataProto:
        data: Dict[str, Union[torch.Tensor, NDArray]],
        meta_info: Optional[Dict[str, Any]] = None,
    ) -> "DataProto":
-        tensors = {}
-        non_tensors = {}
+        tensors, non_tensors = {}, {}
        for key, value in data.items():
            if isinstance(value, torch.Tensor):
                tensors[key] = value
@@ -551,7 +551,7 @@ class DataProto:
        """
        indices_np = indices.detach().numpy()
        self.batch = self.batch[indices]
-        self.non_tensor_batch = {key: val[indices_np] for key, val in self.non_tensor_batch.items()}
+        self.non_tensor_batch = {key: value[indices_np] for key, value in self.non_tensor_batch.items()}
    def repeat(self, repeat_times: int = 2, interleave: bool = True) -> "DataProto":
        """
@@ -666,9 +666,9 @@ def allgather_dict_tensors(
    output = {}
    sorted_keys = sorted(tensors_as_dict.keys())
    for key in sorted_keys:
-        val = tensors_as_dict[key]
-        output[key] = [torch.empty_like(val) for _ in range(size)]
-        torch.distributed.all_gather(output[key], val, group=group, async_op=False)
+        value = tensors_as_dict[key]
+        output[key] = [torch.empty_like(value) for _ in range(size)]
+        torch.distributed.all_gather(output[key], value, group=group, async_op=False)
        output[key] = torch.cat(output[key], dim=dim)
    if is_tensor_dict:
......
@@ -28,13 +28,13 @@ class ResourcePool:
    """The resource pool with meta info such as world size."""
    def __init__(
-        self, process_on_nodes: Optional[Any] = None, max_collocate_count: int = 10, n_gpus_per_node: int = 8
+        self, process_on_nodes: Optional[Any] = None, max_colocate_count: int = 10, n_gpus_per_node: int = 8
    ) -> None:
        if process_on_nodes is None:
            process_on_nodes = []
        self._store = process_on_nodes
-        self.max_collocate_count = max_collocate_count
+        self.max_colocate_count = max_colocate_count
        self.n_gpus_per_node = n_gpus_per_node # this is left for future huawei GPU that contains 16 GPUs per node
    def add_node(self, process_count):
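Note the keyword rename from `max_collocate_count` to `max_colocate_count`: any caller passing it by name must be updated. A small illustrative call based only on the signature shown above (the node layout values are made up):

```python
# Two nodes with 8 processes each; the keyword now uses the corrected spelling.
pool = ResourcePool(process_on_nodes=[8, 8], max_colocate_count=10, n_gpus_per_node=8)
```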
......