Commit f17c9671 authored by zzg_666

first commit
# -*- coding: utf-8 -*-
# Copyright 2025 BAAI. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import re
from PIL import Image
import numpy as np
import torch
from transformers import GenerationConfig
from transformers.generation import LogitsProcessorList
from .logits_processor import UnbatchedClassifierFreeGuidanceLogitsForVisualTokenWithDifferentialTopKProcessor
@torch.no_grad()
def generate(
cfg,
model,
tokenizer,
input_ids,
unconditional_ids,
full_unconditional_ids=None,
force_same_image_size=True,
):
if cfg.streaming:
raise NotImplementedError("Streaming generation is not supported")
# yield from streaming_generate(cfg, model, tokenizer, input_ids, unconditional_ids, full_unconditional_ids)
else:
yield non_streaming_generate(cfg, model, tokenizer, input_ids, unconditional_ids, full_unconditional_ids, force_same_image_size)
def streaming_generate(
cfg,
model,
tokenizer,
input_ids,
unconditional_ids,
full_unconditional_ids=None,
):
pass
def non_streaming_generate(
cfg,
model,
tokenizer,
input_ids,
unconditional_ids,
full_unconditional_ids=None,
force_same_image_size=True,
):
input_ids_len = input_ids.shape[1]
logits_processor = LogitsProcessorList()
logits_processor.append(
build_logits_processor(
cfg,
unconditional_ids,
model,
tokenizer,
full_unconditional_ids,
force_same_image_size=force_same_image_size,
)
)
generation_config = GenerationConfig(
**cfg.sampling_params,
pad_token_id=cfg.special_token_ids["PAD"],
eos_token_id=cfg.special_token_ids["EOS"],
)
token_ids = model.generate(
input_ids,
generation_config,
logits_processor=logits_processor,
)
gen_token_ids = token_ids[:, input_ids_len:]
return gen_token_ids[0].detach().cpu().numpy()
def build_logits_processor(
cfg,
unconditional_ids,
model,
tokenizer,
full_unconditional_ids=None,
force_same_image_size=True,
):
logits_processor = UnbatchedClassifierFreeGuidanceLogitsForVisualTokenWithDifferentialTopKProcessor(
guidance_scale=cfg.classifier_free_guidance,
unconditional_ids=unconditional_ids,
full_unconditional_ids=full_unconditional_ids,
model=model,
tokenizer=tokenizer,
unconditional_type=cfg.unconditional_type,
target_height=getattr(cfg, "target_height", None),
target_width=getattr(cfg, "target_width", None),
image_cfg_scale=getattr(cfg, "image_cfg_scale", 1.0),
use_differential_sampling=cfg.sampling_params["use_differential_sampling"],
text_top_k=cfg.sampling_params["text_top_k"],
text_top_p=cfg.sampling_params["text_top_p"],
text_temperature=cfg.sampling_params["text_temperature"],
image_top_k=cfg.sampling_params["image_top_k"],
image_top_p=cfg.sampling_params["image_top_p"],
image_temperature=cfg.sampling_params["image_temperature"],
force_same_image_size=force_same_image_size,
)
return logits_processor
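# Note (added for clarity): cfg.sampling_params is expected to supply the
# differential-sampling keys indexed above (use_differential_sampling,
# text_top_k, text_top_p, text_temperature, image_top_k, image_top_p,
# image_temperature) in addition to the standard GenerationConfig fields
# unpacked in non_streaming_generate.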
@torch.no_grad()
def multimodal_decode(
outputs,
tokenizer,
vision_tokenizer,
):
outputs = outputs.replace("<|extra_101|>", "").replace("<|extra_204|>", "")
pattern = re.compile(
rf"({re.escape(tokenizer.bog_token)}.*?{re.escape(tokenizer.eog_token)}|"
rf"{re.escape(tokenizer.boc_token)}.*?{re.escape(tokenizer.eoc_token)}|"
rf"{re.escape(tokenizer.boi_token)}.*?{re.escape(tokenizer.eoi_token)})",
re.DOTALL
)
multimodal_output = []
chunks = re.split(pattern, outputs)
for c in chunks:
if len(c) == 0:
continue
if tokenizer.boi_token in c and tokenizer.eoi_token in c:
image = decode_image(c, tokenizer, vision_tokenizer)
if image is not None:
multimodal_output.append(("image", image))
elif tokenizer.bog_token in c and tokenizer.eog_token in c:
multimodal_output.append(
("global_cot", c.replace(tokenizer.bog_token, "").replace(tokenizer.eog_token, ""))
)
elif tokenizer.boc_token in c and tokenizer.eoc_token in c:
multimodal_output.append(
("image_cot", c.replace(tokenizer.boc_token, "").replace(tokenizer.eoc_token, ""))
)
# exclude incomplete image
elif tokenizer.boi_token not in c and len(c.strip()) > 0:
multimodal_output.append(("text", c))
return multimodal_output
def decode_image(image_string, tokenizer, vision_tokenizer):
image = []
image_rows = re.split(re.escape(tokenizer.eol_token), image_string)
for r in image_rows:
token_ids = re.findall(r"<\|visual token (\d+)\|>", r)
if len(token_ids) > 0:
row_token = [int(m) for m in token_ids]
image.append(row_token)
try:
image = torch.tensor(image, dtype=torch.long, device=next(iter(vision_tokenizer.parameters())).device)
h, w = image.shape
image = vision_tokenizer.decode_code(image[None], shape=(1, h, w, 256)).float()
image = image[0].permute(1, 2, 0)
image = Image.fromarray(((image + 1.0) * 127.5).clamp(0, 255).detach().cpu().numpy().astype(np.uint8))
return image
except Exception as ex:
print(f"decode image failed: {ex}")
return None
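# --- Illustrative usage sketch (not part of the original commit) ---
# A minimal end-to-end decode loop, assuming `cfg`, `model`, `tokenizer`,
# `vision_tokenizer`, `input_ids` and `unconditional_ids` were prepared
# elsewhere (e.g. via the builder utilities in this commit). The output
# file name is a placeholder.
def _example_decode_loop(cfg, model, tokenizer, vision_tokenizer, input_ids, unconditional_ids):
    for gen_token_ids in generate(cfg, model, tokenizer, input_ids, unconditional_ids):
        # Convert generated token ids back to a string, then split it into
        # (type, content) chunks: "text", "global_cot", "image_cot", "image".
        output_str = tokenizer.decode(gen_token_ids)
        for kind, content in multimodal_decode(output_str, tokenizer, vision_tokenizer):
            if kind == "image":
                content.save("generated.png")  # content is a PIL.Image
            else:
                print(kind, content)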
# -*- coding: utf-8 -*-
# Copyright 2025 BAAI. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
from PIL import Image
import torch
import numpy as np
def smart_resize(image: Image.Image, area: int = 512 * 512, ds_factor: int = 16):
width, height = image.size
aspect_ratio = width / height
new_height = int((area / aspect_ratio) ** 0.5)
new_width = int(new_height * aspect_ratio)
# Round each side to the nearest multiple of ds_factor
new_height = ((new_height + ds_factor//2) // ds_factor) * ds_factor
new_width = ((new_width + ds_factor//2) // ds_factor) * ds_factor
return image.resize((new_width, new_height), Image.BICUBIC)
def format_image_string(tokenizer, image_tokens):
image_string = ""
h, w = image_tokens.shape
for _h in range(h):
row_string = ""
for _w in range(w):
row_string += "<|visual token {token_id:0>6d}|>".format(token_id=image_tokens[_h, _w])
if _h < h - 1:
row_string += tokenizer.eol_token
image_string += row_string
return "{image_start}{token_height}*{token_width}{image_token}{token_str}{image_end}".format(
image_start=tokenizer.boi_token,
token_height=h,
token_width=w,
image_token=tokenizer.img_token,
token_str=image_string,
image_end=tokenizer.eoi_token,
)
@torch.no_grad()
def build_image(image, cfg, tokenizer, vq_model):
image = smart_resize(image, cfg.image_area)
w, h = image.size
device = next(vq_model.parameters()).device
dtype = next(vq_model.parameters()).dtype
image = torch.tensor((np.array(image) / 127.5 - 1.0)).to(device, dtype).permute(2, 0, 1)
_, _, token = vq_model.encode(image[None])
token = token[-1].view(h // 16, w // 16)
return format_image_string(tokenizer, token)
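# --- Illustrative usage sketch (not part of the original commit) ---
# How an input image might be turned into the in-context visual-token string,
# assuming `cfg`, `tokenizer` and `vq_model` are built elsewhere in this repo.
# The file path below is a placeholder.
def _example_encode_reference_image(cfg, tokenizer, vq_model, path="reference.jpg"):
    image = Image.open(path).convert("RGB")
    # smart_resize keeps the aspect ratio, targets roughly cfg.image_area pixels
    # and snaps both sides to multiples of the VQ downsampling factor (16 here);
    # build_image then encodes the image and serializes rows joined by eol_token.
    image_string = build_image(image, cfg, tokenizer, vq_model)
    return image_string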
# -*- coding:utf-8 -*-
# Copyright 2025 BAAI. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import os
from datetime import datetime
import os.path as osp
import builtins
old_print = builtins.print
def setup_print_file(file):
def print(*args, **kwargs):
msg = " ".join(map(str, args))
with open(file, "a") as f:
f.write(msg + "\n")
old_print(msg)
builtins.print = print
def setup_logger(log_dir="./", log_name="log"):
logfile = osp.join(
log_dir,
f'{log_name}_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log',
)
os.makedirs(osp.dirname(logfile), exist_ok=True)
setup_print_file(logfile)
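# --- Illustrative usage sketch (not part of the original commit) ---
# setup_logger redirects builtins.print so every subsequent print() call is
# appended to a timestamped log file and still echoed to stdout. The log
# directory and name below are placeholders.
if __name__ == "__main__":
    setup_logger(log_dir="./logs", log_name="inference")
    print("logging to file and to stdout")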
# -*- coding: utf-8 -*-
# Copyright 2025 BAAI. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import os.path as osp
import torch
from transformers import AutoTokenizer
from ..emu3p5 import Emu3ForCausalLM, Emu3Config
from ..vision_tokenizer import build_vision_tokenizer
def build_emu3p5(
model_path,
tokenizer_path,
vq_path,
vq_type="ibq",
model_device="auto",
vq_device="cuda:0",
**kwargs,
):
if isinstance(model_device, int):
device_map = f"cuda:{model_device}"
else:
device_map = model_device
print(device_map)
# MLLM
model_config = Emu3Config.from_pretrained(
model_path,
trust_remote_code=True,
)
model = Emu3ForCausalLM.from_pretrained(
model_path,
config=model_config,
torch_dtype=torch.bfloat16,
device_map=device_map,
attn_implementation="flash_attention_2",
# attn_implementation="eager", # if you cannot install flash_attention
)
model.eval()
# text tokenizer
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_path,
special_tokens_file=osp.join(tokenizer_path, "emu3_vision_tokens.txt"),
trust_remote_code=True,
)
tokenizer.bos_token = "<|extra_203|>"
tokenizer.eos_token = "<|extra_204|>"
tokenizer.pad_token = "<|endoftext|>"
tokenizer.eol_token = "<|extra_200|>"
tokenizer.eof_token = "<|extra_201|>"
tokenizer.tms_token = "<|extra_202|>"
tokenizer.img_token = "<|image token|>"
tokenizer.boi_token = "<|image start|>"
tokenizer.eoi_token = "<|image end|>"
tokenizer.bss_token = "<|extra_100|>"
tokenizer.ess_token = "<|extra_101|>"
tokenizer.bog_token = "<|extra_60|>"
tokenizer.eog_token = "<|extra_61|>"
tokenizer.boc_token = "<|extra_50|>"
tokenizer.eoc_token = "<|extra_51|>"
# vq tokenizer
vq_model = build_vision_tokenizer(vq_type, vq_path, device=vq_device, **kwargs)
return model, tokenizer, vq_model
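# --- Illustrative usage sketch (not part of the original commit) ---
# Loading all three components; the checkpoint directories below are
# placeholders for wherever the Emu3.5 weights, tokenizer files and the IBQ
# vision tokenizer actually live.
def _example_build(
    model_path="/path/to/Emu3.5",
    tokenizer_path="/path/to/Emu3.5",
    vq_path="/path/to/vision_tokenizer",
):
    model, tokenizer, vq_model = build_emu3p5(
        model_path,
        tokenizer_path,
        vq_path,
        vq_type="ibq",
        model_device="auto",
        vq_device="cuda:0",
    )
    return model, tokenizer, vq_model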
# -*- coding: utf-8 -*-
# Copyright 2025 BAAI. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import io
import re
import os.path as osp
from PIL import Image
from src.proto import emu_pb as story_pb
class ProtoWriter:
def __init__(self):
self.story = story_pb.Story()
self.image_tensor = None
def clear(self):
self.story = story_pb.Story()
self.image_tensor = None
def extend(self, multimodal_output):
for t, c in multimodal_output:
match t:
case "question":
self.story.question = c
case "global_cot":
self.story.summary = c
case "image_cot":
image = story_pb.ImageMeta()
image.chain_of_thought = c
self._put_last_image(image)
case "text":
self._put_last_clip(self._build_clip(c))
case "image":
image = self._get_last_image()
image.image.CopyFrom(self._build_image(c))
self._put_last_image(image)
case "reference_image":
image = story_pb.ImageMeta()
image.image.CopyFrom(self._build_image(c))
self.story.reference_images.append(image)
case _:
raise NotImplementedError(f"Unsupported data type {t}")
def save(self, path):
self._check_last_image()
with open(path, 'wb') as f:
f.write(self.story.SerializeToString())
def _build_clip(self, text_content=""):
clip = story_pb.Clip()
clip.clip_id = f"clip_{len(self.story.clips):04d}"
segment = story_pb.Segment()
segment.asr = text_content
clip.segments.append(segment)
return clip
def _build_image(self, image):
im = story_pb.Image()
im.width, im.height = image.size
im.format = story_pb.ImageFormat.PNG
img_byte_arr = io.BytesIO()
image.save(img_byte_arr, format="PNG")
im.image_data = img_byte_arr.getvalue()
return im
def _get_last_image(self):
if not self.story.clips:
self._put_last_clip(self._build_clip())
if self.story.clips[-1].segments[0].images and not self.story.clips[-1].segments[0].images[-1].image.image_data:
image = self.story.clips[-1].segments[0].images[-1]
del self.story.clips[-1].segments[0].images[-1]
else:
image = story_pb.ImageMeta()
return image
def _put_last_image(self, image):
if not self.story.clips:
self._put_last_clip(self._build_clip())
self.story.clips[-1].segments[0].images.append(image)
def _put_last_clip(self, clip):
self.story.clips.append(clip)
def _check_last_image(self):
image = self._get_last_image()
if image.image.image_data:
self._put_last_image(image)
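# --- Illustrative usage sketch (not part of the original commit) ---
# Serializing decoded multimodal output into a protobuf Story file. The
# (type, content) tuples mirror what multimodal_decode() in this commit
# produces; the question text and output path are placeholders.
def _example_write_story(multimodal_output, path="story.pb"):
    writer = ProtoWriter()
    writer.extend([("question", "user prompt goes here")])
    writer.extend(multimodal_output)  # e.g. [("text", ...), ("image", PIL.Image), ...]
    writer.save(path)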
# -*- coding: utf-8 -*-
# Copyright 2025 BAAI. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import os
import os.path as osp
import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
import imageio
def wrap_text(draw, text, font, max_width):
lines = []
current_line = ""
i = 0
while i < len(text):
char = text[i]
test_line = current_line + char
bbox = draw.textbbox((0, 0), test_line, font=font)
text_width = bbox[2] - bbox[0]
if text_width <= max_width:
current_line = test_line
i += 1
else:
if current_line:
lines.append(current_line)
current_line = ""
else:
current_line = char
i += 1
if current_line:
lines.append(current_line)
return lines
def plot_string(string, font_path="src/proto/assets/cangerjinkai.ttf", font_size=80, image_size=(512, 512), bg_color="white", text_color="black"):
img = Image.new("RGB", image_size, color=bg_color)
draw = ImageDraw.Draw(img)
margin = 100
max_width = max(image_size[0] - 2 * margin, 1)
max_height = max(image_size[1] - 2 * margin, 1)
def load_font(size):
if font_path:
try:
return ImageFont.truetype(font_path, size)
except Exception:
print(f"Failed to load font from {font_path}")
return ImageFont.load_default()
font = load_font(font_size)
lines = wrap_text(draw, string, font, max_width)
line_height = draw.textbbox((0, 0), "Ay", font=font)[3]
total_text_height = line_height * max(len(lines), 1)
if total_text_height > max_height:
for size in range(font_size - 2, 9, -2):
font = load_font(size)
lines = wrap_text(draw, string, font, max_width)
line_height = draw.textbbox((0, 0), "Ay", font=font)[3]
total_text_height = line_height * max(len(lines), 1)
if total_text_height <= max_height:
break
else:
font = ImageFont.load_default()
lines = wrap_text(draw, string, font, max_width)
line_height = draw.textbbox((0, 0), "Ay", font=font)[3]
total_text_height = line_height * max(len(lines), 1)
y_offset = max(margin, (image_size[1] - total_text_height) // 2)
for line in lines:
bbox = draw.textbbox((0, 0), line, font=font)
text_width = bbox[2] - bbox[0]
x_offset = max(margin, (image_size[0] - text_width) // 2)
draw.text((x_offset, y_offset), line, fill=text_color, font=font)
y_offset += line_height
return np.array(img)
def save_image_list_to_video(images, path, fps=1, quality='high'):
os.makedirs(osp.dirname(path), exist_ok=True)
if '.mp4' not in path and len(images) == 1:
img = images[0]
if isinstance(img, torch.Tensor):
img = img.detach().cpu().numpy().astype(np.uint8)
elif isinstance(img, Image.Image):
img = np.array(img).astype(np.uint8)
else:
img = img.astype(np.uint8)
Image.fromarray(img).save(path, quality=100)
return
func = lambda x: (
x.detach().cpu().numpy().astype(np.uint8)
if isinstance(x, torch.Tensor)
else x.astype(np.uint8)
)
images = list(map(func, images))
if quality == 'high':
try:
writer = imageio.get_writer(
path,
fps=fps,
codec='libx264',
ffmpeg_params=[
'-crf', '18',
'-preset', 'slow',
'-pix_fmt', 'yuv420p',
]
)
for image in images:
writer.append_data(image)
writer.close()
except (TypeError, AttributeError):
try:
writer = imageio.get_writer(path, fps=fps, codec='libx264', macro_block_size=None)
for image in images:
writer.append_data(image)
writer.close()
except Exception:
with imageio.get_writer(path, fps=fps, mode='I') as writer:
for image in images:
writer.append_data(image)
else:
with imageio.get_writer(path, fps=fps, mode='I') as writer:
for image in images:
writer.append_data(image)
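# --- Illustrative usage sketch (not part of the original commit) ---
# Rendering a prompt card followed by generated frames into an mp4, assuming
# `frames` is a non-empty list of same-sized HxWx3 uint8 numpy arrays (torch
# tensors are also handled by save_image_list_to_video). The output path is a
# placeholder.
def _example_export_video(frames, prompt="a short prompt", path="out/story.mp4", fps=1):
    height, width = frames[0].shape[:2]
    title_card = plot_string(prompt, image_size=(width, height))
    save_image_list_to_video([title_card] + list(frames), path, fps=fps, quality="high")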
# -*- coding: utf-8 -*-
# Copyright 2025 BAAI. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import os
import sys
import argparse
import numpy as np
from PIL import Image
import io
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
if project_root not in sys.path:
sys.path.insert(0, project_root)
from src.proto import emu_pb as story_pb
from src.utils.video_utils import plot_string, save_image_list_to_video
def main():
parser = argparse.ArgumentParser(description='Visualize protobuf story files')
parser.add_argument('--input', '-i', required=True, help='Input protobuf file path')
parser.add_argument('--output', '-o', required=True, help='Output directory path')
parser.add_argument('--video', action='store_true', help='Generate video from protobuf content')
parser.add_argument('--fps', type=int, default=1, help='Frames per second for video (default: 1)')
args = parser.parse_args()
input_path = args.input
output_path = args.output
os.makedirs(output_path, exist_ok=True)
with open(input_path, 'rb') as f:
story = story_pb.Story()
story.ParseFromString(f.read())
with open(f"{output_path}/000_question.txt", 'w') as f:
print(story.question, file=f)
if story.summary and story.summary.strip():
with open(f"{output_path}/000_global_cot.txt", 'w') as f:
print(story.summary, file=f)
idx = 1
if len(story.reference_images) > 0:
for i in range(len(story.reference_images)):
with open(f"{output_path}/{i:03d}_reference_image.png", 'wb') as f:
f.write(story.reference_images[i].image.image_data)
idx = len(story.reference_images)
for c in story.clips:
for s in c.segments:
with open(f"{output_path}/{idx:03d}_text.txt", 'w') as f:
print(s.asr, file=f)
for im_idx, im in enumerate(s.images):
with open(f"{output_path}/{idx:03d}_{im_idx:02d}_image.png", 'wb') as f:
f.write(im.image.image_data)
if im.chain_of_thought and im.chain_of_thought.strip():
with open(f"{output_path}/{idx:03d}_{im_idx:02d}_image_cot.txt", 'w') as f:
print(im.chain_of_thought, file=f)
idx += 1
if args.video:
video_images = []
target_size = None
for ref_img_data in story.reference_images:
img = Image.open(io.BytesIO(ref_img_data.image.image_data))
img = img.convert('RGB')
if target_size is None:
target_size = img.size
for c in story.clips:
for s in c.segments:
for im in s.images:
img = Image.open(io.BytesIO(im.image.image_data))
img = img.convert('RGB')
if target_size is None:
target_size = img.size
if target_size is None:
target_size = (512, 512)
if story.question and story.question.strip():
question_img = plot_string(story.question, image_size=(target_size[0], target_size[1]))
video_images.append(question_img)
for img_array in story.reference_images:
img = Image.open(io.BytesIO(img_array.image.image_data))
img = img.convert('RGB')
if img.size != target_size:
img = img.resize(target_size, Image.Resampling.LANCZOS)
video_images.append(np.array(img))
for c in story.clips:
for s in c.segments:
if s.asr and s.asr.strip():
asr_img = plot_string(s.asr, image_size=(target_size[0], target_size[1]))
video_images.append(asr_img)
for im in s.images:
img = Image.open(io.BytesIO(im.image.image_data))
img = img.convert('RGB')
if img.size != target_size:
img = img.resize(target_size, Image.Resampling.LANCZOS)
video_images.append(np.array(img))
if video_images:
video_path = f"{output_path}/video.mp4"
save_image_list_to_video(video_images, video_path, fps=args.fps, quality='high')
print(f"Video saved to: {video_path}")
if __name__ == "__main__":
main()
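# Example invocation (illustrative; the script name and file paths below are
# placeholders for wherever this file lives in the repo):
#   python visualize_story.py --input story.pb --output out_dir --video --fps 1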
import os.path as osp
from omegaconf import OmegaConf
import torch
from .ibq import IBQ
def build_vision_tokenizer(type, model_path, device="cuda:0", **kwargs):
match type:
case "ibq":
cfg = OmegaConf.load(osp.join(model_path, "config.yaml"))
tokenizer = IBQ(**cfg)
ckpt = torch.load(osp.join(model_path, "model.ckpt"), map_location="cpu")
tokenizer.load_state_dict(ckpt)
tokenizer.eval().to(device)
return tokenizer
case _:
raise NotImplementedError(f"Unsupported vision tokenizer type: {type}")
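# --- Illustrative usage sketch (not part of the original commit) ---
# Loading the IBQ vision tokenizer from a directory containing config.yaml and
# model.ckpt; the path is a placeholder.
def _example_load_vision_tokenizer(path="/path/to/vision_tokenizer", device="cuda:0"):
    return build_vision_tokenizer("ibq", path, device=device)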
# -*- coding:utf-8 -*-
from torch import nn
from .modules.diffusionmodules.model import Encoder, Decoder
from .modules.vqvae.quantize import IndexPropagationQuantize
class IBQ(nn.Module):
def __init__(
self,
ddconfig,
n_embed,
embed_dim,
beta=0.25,
use_entropy_loss=False,
cosine_similarity=False,
entropy_temperature=0.01,
sample_minimization_weight=1.0,
batch_maximization_weight=1.0,
**kwargs,
):
super().__init__()
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.quantize = IndexPropagationQuantize(
n_embed,
embed_dim,
beta,
use_entropy_loss,
cosine_similarity=cosine_similarity,
entropy_temperature=entropy_temperature,
sample_minimization_weight=sample_minimization_weight,
batch_maximization_weight=batch_maximization_weight,
)
self.quant_conv = nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
self.post_quant_conv = nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
def encode(self, x):
h = self.encoder(x)
h = self.quant_conv(h)
quant, emb_loss, info = self.quantize(h)
return quant, emb_loss, info
def decode(self, quant, return_intermediate_feature=False):
quant = self.post_quant_conv(quant)
dec = self.decoder(quant, return_intermediate_feature=return_intermediate_feature)
return dec
def decode_code(self, code_b, shape=None):
# shape specifying (batch, height, width, channel)
quant_b = self.quantize.get_codebook_entry(code_b, shape=shape)
dec = self.decode(quant_b)
return dec
def forward(self, input, return_intermediate_feature=False):
quant, diff, _ = self.encode(input)
dec = self.decode(quant, return_intermediate_feature=return_intermediate_feature)
return dec, diff
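# --- Illustrative encode/decode round trip (not part of the original commit) ---
# Assumes `ibq` is an IBQ instance (e.g. loaded via build_vision_tokenizer)
# and `x` is a (B, 3, H, W) tensor scaled to [-1, 1] whose spatial size is
# divisible by the encoder's downsampling factor. Run under torch.no_grad()
# for inference.
def _example_roundtrip(ibq, x):
    quant, emb_loss, info = ibq.encode(x)  # quantized latents, codebook loss, indices
    recon = ibq.decode(quant)              # reconstruction, roughly in [-1, 1]
    return recon, emb_loss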