# run.py — generate sample images with the ERNIE-Image diffusers pipeline.
import os
import random
import numpy as np
import torch
from diffusers import ErnieImagePipeline
from tqdm import tqdm

# Draw a fresh seed on every run and print it so the run can be repeated
# later. To pin the seed, replace the randint call with a constant
# (for example: seed = 42).
seed = random.randint(0, 100000)
print(f"seed: {seed}")

# Seed every RNG in use: Python's, NumPy's, torch's default generator and all
# CUDA devices. Per the original author's note, on DCU builds
# torch.cuda.manual_seed_all is mapped to the underlying hipRAND.
for seeder in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed_all,
):
    seeder(seed)

# Request deterministic kernels and switch off benchmark-based auto-tuning,
# favouring repeatable output over speed. Per the original author's note, on
# DCU these cudnn settings are taken over by MIOpen.
cudnn = torch.backends.cudnn
cudnn.deterministic = True
cudnn.benchmark = False

# Load the pipeline.
# NOTE: if your DCU (e.g. some older models) has poor bfloat16 support,
# try torch.float16 instead.
pipe = ErnieImagePipeline.from_pretrained(
    "baidu/ERNIE-Image",
    torch_dtype=torch.bfloat16,
)

# Per the original author's note, the DCU build of PyTorch maps "cuda" to the
# DCU device automatically.
pipe = pipe.to("cuda")

# Put every model component in eval mode before inference.
pipe.transformer.eval()
pipe.vae.eval()
pipe.text_encoder.eval()
pipe.pe.eval()

# If device memory is insufficient, model CPU offload can be enabled:
# pipe.enable_model_cpu_offload()

# Seeded device-side generator that is passed to each pipeline call below.
# Per the original author's note, "cuda" is mapped to the DCU here as well.
generator = torch.Generator(device="cuda")
generator.manual_seed(seed)

# Make sure the output directory exists before any image is saved.
os.makedirs("../tests", exist_ok=True)

# Prompts to generate images for; each entry is passed unchanged as `prompt`
# to the pipeline in the loop below (one English diagram prompt, one mostly
# Chinese architecture prompt, one English photo prompt).
# NOTE(review): the trailing "--ar", "--v", "--quality", "--style" flags look
# like Midjourney-style options; here they are sent as plain prompt text —
# confirm that this is intended.
prompt_list = [
    "A highly detailed biological pathway diagram in BioRender style. Depicting the viral infection process of human immune cells. Showing a virus particle attaching to a T-cell receptor, viral RNA replicating inside the cell nucleus, and the cell transforming into a malignant tumor cell. Includes molecular signaling pathways, proteins, and epigenetic modification symbols. Scientific flat vector style, soft pastel medical color palette, clean white background, educational graphic, crisp lines, professional scientific journal illustration. --ar 16:9 --v 6.0",
    "建筑坐落在住宅小区道路旁,被高大浓密的绿色乔木包围,树冠形成自然遮荫空间,严格保持图中所有元素的一致性。 午后自然阳光透过树叶形成斑驳光影(dappled sunlight),光线柔和且具有层次,地面呈现清晰树影,空气通透,微风感 建筑下方为开放式咖啡空间,人群自然分布,轻松社交状态,室内外边界模糊,空间通透流动 摄影级真实渲染,超高动态范围(HDR),自然曝光,真实反射与折射,玻璃微反射环境,极致细节。SANAA建筑风格,极简主义,轻盈漂浮感,日式当代建筑语言,自然主义建筑融合 高级建筑摄影风格,类似Dezeen / ArchDaily封面级别,纪实但理想化 电影级光影(cinematic natural lighting),柔和高光不过曝,阴影细腻。tree canopy filtered sunlight, soft shadows, volumetric light subtle, natural ambient occlusion 阳光穿过树叶产生细碎光斑,地面光影随机分布,光影边界柔和不过硬 色反射光(green bounce light)轻微影响建筑底部色温。广角镜头 24mm,低机位轻微仰视,前景草地,中景建筑,背景树林 画面有树枝作为前景遮挡(frame foreground leaves),增加空间层次 景深适中,整体清晰但有空气透视。--quality 2 --style raw --ar 3:4 --lighting natural --render photorealistic --detail ultra",
    "A photograph of the Straw Hat Pirates drawn on a glass whiteboard with a faded green marker, front view, 4K resolution."
]

# Render every prompt at 1024x1024, save the first returned image as
# ../tests/hf_output<N>.png (N counted from 1) and print the revised prompt(s)
# reported by the pipeline. The same `generator` object is reused for all
# prompts, so each result also depends on the prompts rendered before it.
for number, prompt in enumerate(prompt_list, start=1):
    output = pipe(
        prompt=prompt,
        height=1024,
        width=1024,
        num_inference_steps=50,
        guidance_scale=5.0,
        generator=generator,
    )
    revised_prompt = output.revised_prompts
    images = output.images
    images[0].save(f"../tests/hf_output{number}.png")
    print(f"Prompt {number} revised: {revised_prompt}")