Commit d5878167 (mashun1): llava-next
import torch
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, KeywordsStoppingCriteria
from transformers.generation.streamers import TextIteratorStreamer
from PIL import Image
import requests
from io import BytesIO
from cog import BasePredictor, Input, Path, ConcatenateIterator
import time
import subprocess
from threading import Thread
import os
os.environ["HUGGINGFACE_HUB_CACHE"] = os.getcwd() + "/weights"
# url for the weights mirror
REPLICATE_WEIGHTS_URL = "https://weights.replicate.delivery/default"
# files to download from the weights mirrors
weights = [
{
"dest": "liuhaotian/llava-v1.5-13b",
# git commit hash from huggingface
"src": "llava-v1.5-13b/006818fc465ebda4c003c0998674d9141d8d95f8",
"files": [
"config.json",
"generation_config.json",
"pytorch_model-00001-of-00003.bin",
"pytorch_model-00002-of-00003.bin",
"pytorch_model-00003-of-00003.bin",
"pytorch_model.bin.index.json",
"special_tokens_map.json",
"tokenizer.model",
"tokenizer_config.json",
],
},
{
"dest": "openai/clip-vit-large-patch14-336",
"src": "clip-vit-large-patch14-336/ce19dc912ca5cd21c8a653c79e251e808ccabcd1",
"files": ["config.json", "preprocessor_config.json", "pytorch_model.bin"],
},
]
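# Note (hedged reading of the code below): for each entry, "src" is the path under
# REPLICATE_WEIGHTS_URL to fetch from, and "dest" is the local directory (relative to
# the working directory) the files are written into; setup() wires them together via
# download_weights(weight["src"], weight["dest"], weight["files"]).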
def download_json(url: str, dest: Path):
res = requests.get(url, allow_redirects=True)
if res.status_code == 200 and res.content:
with dest.open("wb") as f:
f.write(res.content)
else:
print(f"Failed to download {url}. Status code: {res.status_code}")
def download_weights(baseurl: str, basedest: str, files: list[str]):
basedest = Path(basedest)
start = time.time()
print("downloading to: ", basedest)
basedest.mkdir(parents=True, exist_ok=True)
for f in files:
dest = basedest / f
url = os.path.join(REPLICATE_WEIGHTS_URL, baseurl, f)
if not dest.exists():
print("downloading url: ", url)
if dest.suffix == ".json":
download_json(url, dest)
else:
subprocess.check_call(["pget", url, str(dest)], close_fds=False)
print("downloading took: ", time.time() - start)
class Predictor(BasePredictor):
def setup(self) -> None:
"""Load the model into memory to make running multiple predictions efficient"""
for weight in weights:
download_weights(weight["src"], weight["dest"], weight["files"])
disable_torch_init()
self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model("liuhaotian/llava-v1.5-13b", model_name="llava-v1.5-13b", model_base=None, load_8bit=False, load_4bit=False)
def predict(
self,
image: Path = Input(description="Input image"),
prompt: str = Input(description="Prompt to use for text generation"),
top_p: float = Input(description="When decoding text, samples from the smallest set of most likely tokens whose cumulative probability reaches top_p; lower to ignore less likely tokens", ge=0.0, le=1.0, default=1.0),
temperature: float = Input(description="Adjusts randomness of outputs, greater than 1 is random and 0 is deterministic", default=0.2, ge=0.0),
max_tokens: int = Input(description="Maximum number of tokens to generate. A word is generally 2-3 tokens", default=1024, ge=0),
) -> ConcatenateIterator[str]:
"""Run a single prediction on the model"""
conv_mode = "llava_v1"
conv = conv_templates[conv_mode].copy()
image_data = load_image(str(image))
image_tensor = self.image_processor.preprocess(image_data, return_tensors="pt")["pixel_values"].half().cuda()
# single-turn conversation: always prepend the image token to the user prompt
inp = DEFAULT_IMAGE_TOKEN + "\n" + prompt
conv.append_message(conv.roles[0], inp)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).cuda()
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids)
streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, timeout=20.0)
with torch.inference_mode():
thread = Thread(
target=self.model.generate,
kwargs=dict(inputs=input_ids, images=image_tensor, do_sample=True, temperature=temperature, top_p=top_p, max_new_tokens=max_tokens, streamer=streamer, use_cache=True, stopping_criteria=[stopping_criteria]),
)
thread.start()
# workaround: the streamer sometimes yields a lone " " chunk; buffer it and
# re-attach it to the following chunk, unless that chunk ends with the stop string
prepend_space = False
for new_text in streamer:
if new_text == " ":
prepend_space = True
continue
if new_text.endswith(stop_str):
new_text = new_text[: -len(stop_str)].strip()
prepend_space = False
elif prepend_space:
new_text = " " + new_text
prepend_space = False
if len(new_text):
yield new_text
if prepend_space:
yield " "
thread.join()
def load_image(image_file):
if image_file.startswith(("http://", "https://")):
response = requests.get(image_file)
image = Image.open(BytesIO(response.content)).convert("RGB")
else:
image = Image.open(image_file).convert("RGB")
return image
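# Hedged usage note (not part of the original file): with Cog installed, a local
# prediction can be run with something like
#   cog predict -i image=@example.jpg -i prompt="Describe this image"
# where the image path and prompt are placeholders.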
[tool.black]
line-length = 240
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
[project]
name = "llava"
version = "1.7.0.dev0"
description = "LLaVA OneVision: The Next Generation of LLaVA with Better Image and Video Understanding Capabilities"
readme = "README.md"
requires-python = ">=3.8"
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: Apache Software License",
]
[project.optional-dependencies]
standalone = [
"shortuuid",
"httpx==0.24.0",
"einops",
"ftfy",
]
train = [
"llava[standalone]",
"numpy==1.26.1",
"open_clip_torch",
"fastapi",
"markdown2[all]",
"numpy",
"requests",
"sentencepiece",
# "torch==2.1.2",
# "torchvision==0.16.2",
"uvicorn",
"wandb",
# "deepspeed==0.14.4",
"peft==0.4.0",
"accelerate>=0.29.1",
"tokenizers~=0.15.2",
# "transformers@git+https://github.com/huggingface/transformers.git@1c39974a4c4036fd641bc1191cc32799f85715a4",
# "bitsandbytes==0.41.0",
"scikit-learn==1.2.2",
"sentencepiece~=0.1.99",
"einops==0.6.1",
"einops-exts==0.0.4",
"gradio_client==0.2.9",
"urllib3<=2.0.0",
"datasets==2.16.1",
"pydantic==1.10.8",
"timm",
"hf_transfer",
"opencv-python",
"av",
"decord",
"tyro",
"scipy",
]
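# The extras above use standard pip syntax, e.g. `pip install -e ".[train]"` or `pip install -e ".[standalone]"`.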
[project.urls]
"Homepage" = "https://llava-vl.github.io"
"Bug Tracker" = "https://github.com/haotian-liu/LLaVA/issues"
[tool.setuptools.packages.find]
include = ["llava*", "trl*"]
exclude = [
"assets*",
"benchmark*",
"docs",
"dist*",
"playground*",
"scripts*",
"tests*",
"checkpoints*",
"project_checkpoints*",
"debug_checkpoints*",
"mlx_configs*",
"wandb*",
"notebooks*",
]
[tool.wheel]
exclude = [
"assets*",
"benchmark*",
"docs",
"dist*",
"playground*",
"scripts*",
"tests*",
"checkpoints*",
"project_checkpoints*",
"debug_checkpoints*",
"mlx_configs*",
"wandb*",
"notebooks*",
]
Babel==2.14.0
DataProperty==1.0.1
Deprecated==1.2.14
GitPython==3.1.43
Jinja2==3.1.3
Levenshtein==0.25.1
MarkupSafe==2.1.5
PyJWT==2.8.0
PyYAML==6.0.1
Pygments==2.17.2
QtPy==2.4.1
Send2Trash==1.8.3
absl-py==2.1.0
accelerate==0.29.3
aiofiles==22.1.0
aiohttp==3.9.5
aiosignal==1.3.1
aiosqlite==0.20.0
altair==5.3.0
anyio==4.3.0
appdirs==1.4.4
argon2-cffi-bindings==21.2.0
argon2-cffi==23.1.0
arrow==1.3.0
asttokens==2.4.1
async-timeout==4.0.3
attrs==23.1.0
beautifulsoup4==4.12.3
bidict==0.23.1
# bitsandbytes==0.41.0
black==24.1.0
bleach==6.1.0
byted-remote-ikernel==0.4.8
byted-torch-monitor==0.0.1
byted-wandb==0.13.72
bytedance-context==0.7.1
bytedance-metrics==0.5.1
bytedance.modelhub==0.0.64
bytedance.servicediscovery==0.1.2
bytedbackgrounds==0.0.6
byteddatabus==1.0.6
byteddps==0.1.2
bytedenv==0.6.2
bytedlogger==0.15.1
bytedmemfd==0.2
bytedmetrics==0.10.2
bytedpymongo==2.0.5
bytedrh2==1.18.7a2
bytedservicediscovery==0.17.4
bytedtcc==1.4.2
bytedtos==1.1.16
bytedtrace==0.3.0
bytedztijwthelper==0.0.22
bytedztispiffe==0.0.11
certifi==2024.2.2
cffi==1.16.0
cfgv==3.4.0
chardet==5.2.0
charset-normalizer==3.3.2
click==8.1.7
colorama==0.4.6
comm==0.2.2
contourpy==1.2.1
crcmod==1.7
cryptography==38.0.4
cycler==0.12.1
datasets==2.16.1
debugpy==1.8.1
decorator==5.1.1
decord==0.6.0
deepspeed==0.12.2
defusedxml==0.7.1
dill==0.3.7
distlib==0.3.8
distro==1.9.0
dnspython==2.6.1
docker-pycreds==0.4.0
docstring_parser==0.16
einops-exts==0.0.4
einops==0.6.1
entrypoints==0.4
et-xmlfile==1.1.0
eval_type_backport==0.2.0
evaluate==0.4.1
exceptiongroup==1.2.1
executing==2.0.1
fastapi==0.110.2
fastjsonschema==2.19.1
ffmpy==0.3.2
filelock==3.13.4
# flash-attn==2.5.7
fonttools==4.51.0
fqdn==1.5.1
frozenlist==1.4.1
fsspec==2023.10.0
ftfy==6.2.0
gitdb==4.0.11
gradio==3.35.2
gradio_client==0.2.9
grpcio==1.62.2
h11==0.14.0
hf_transfer==0.1.6
hjson==3.1.0
httpcore==0.17.3
httpx==0.24.0
huggingface-hub==0.22.2
identify==2.5.36
idna==3.7
importlib_metadata==7.1.0
importlib_resources==6.4.0
iniconfig==2.0.0
ipaddress==1.0.23
ipykernel==6.29.4
ipython-genutils==0.2.0
ipython==8.18.1
ipywidgets==8.1.2
isoduration==20.11.0
jedi==0.19.1
joblib==1.4.0
json5==0.9.25
jsonlines==4.0.0
jsonpointer==2.4
jsonschema-specifications==2023.12.1
jsonschema==4.21.1
jupyter-client==7.0.0
jupyter-console==6.6.3
jupyter-events==0.10.0
jupyter-ydoc==0.2.5
jupyter==1.0.0
jupyter_core==5.7.2
jupyter_server==2.14.0
jupyter_server_fileid==0.9.2
jupyter_server_terminals==0.5.3
jupyter_server_ydoc==0.8.0
jupyterlab==3.6.4
jupyterlab_pygments==0.3.0
jupyterlab_server==2.27.1
jupyterlab_widgets==3.0.10
kiwisolver==1.4.5
linkify-it-py==2.0.3
llava==1.7.0.dev0
lmms_eval==0.1.1
lxml==5.2.1
markdown-it-py==2.2.0
markdown2==2.4.13
matplotlib-inline==0.1.7
matplotlib==3.8.4
mbstrdecoder==1.1.3
mdit-py-plugins==0.3.3
mdurl==0.1.2
mistune==3.0.2
mpmath==1.3.0
msgpack==1.0.8
multidict==6.0.5
multiprocess==0.70.15
mypy-extensions==1.0.0
nbclassic==1.0.0
nbclient==0.10.0
nbconvert==7.16.3
nbformat==5.10.4
nest-asyncio==1.6.0
networkx==3.2.1
ninja==1.11.1.1
nltk==3.8.1
nodeenv==1.8.0
notebook==6.5.6
notebook_shim==0.2.4
numexpr==2.10.0
numpy==1.26.4
# nvidia-cublas-cu12==12.1.3.1
# nvidia-cuda-cupti-cu12==12.1.105
# nvidia-cuda-nvrtc-cu12==12.1.105
# nvidia-cuda-runtime-cu12==12.1.105
# nvidia-cudnn-cu12==8.9.2.26
# nvidia-cufft-cu12==11.0.2.54
# nvidia-curand-cu12==10.3.2.106
# nvidia-cusolver-cu12==11.4.5.107
# nvidia-cusparse-cu12==12.1.0.106
# nvidia-nccl-cu12==2.18.1
# nvidia-nvjitlink-cu12==12.4.127
# nvidia-nvtx-cu12==12.1.105
open-clip-torch==2.24.0
openai==1.23.6
opencv-python-headless==4.9.0.80
openpyxl==3.1.2
orjson==3.10.1
overrides==7.7.0
packaging==24.0
pandas==2.2.2
pandocfilters==1.5.1
parso==0.8.4
pathlib2==2.3.7.post1
pathspec==0.12.1
pathtools==0.1.2
pathvalidate==3.2.0
peft==0.4.0
pexpect==4.8.0
pillow==10.3.0
pip==24.0
platformdirs==4.2.1
pluggy==1.5.0
ply==3.11
portalocker==2.8.2
pre-commit==3.7.0
prometheus_client==0.20.0
promise==2.3
prompt-toolkit==3.0.43
protobuf==3.20.3
psutil==5.9.8
ptyprocess==0.7.0
pure-eval==0.2.2
py-cpuinfo==9.0.0
py-spy==0.3.14
py==1.11.0
pyOpenSSL==22.1.0
pyarrow-hotfix==0.6
pyarrow==16.0.0
pybind11==2.12.0
pycocoevalcap==1.2
pycocotools==2.0.7
pycparser==2.22
pycryptodomex==3.20.0
pydantic==1.10.8
pydub==0.25.1
pynvml==11.5.0
pyparsing==3.1.2
pytablewriter==1.2.0
pytest==6.2.5
python-consul==1.1.0
python-dateutil==2.9.0.post0
python-engineio==4.9.0
python-etcd==0.4.5
python-json-logger==2.0.7
python-multipart==0.0.9
python-socketio==5.11.2
pytz==2024.1
pyzmq==24.0.1
qtconsole==5.5.1
rapidfuzz==3.8.1
referencing==0.35.0
regex==2024.4.16
requests==2.31.0
responses==0.18.0
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rich==13.7.1
rouge-score==0.1.2
rpds-py==0.18.0
sacrebleu==2.4.2
safetensors==0.4.3
schedule==1.2.1
scikit-learn==1.2.2
scipy==1.13.0
semantic-version==2.10.0
sentencepiece==0.1.99
sentry-sdk==2.0.0
setproctitle==1.3.3
setuptools==68.2.2
shortuuid==1.0.13
shtab==1.7.1
simple-websocket==1.0.0
six==1.16.0
smmap==5.0.1
sniffio==1.3.1
soupsieve==2.5
sqlitedict==2.1.0
stack-data==0.6.3
starlette==0.37.2
svgwrite==1.4.3
sympy==1.12
tabledata==1.3.3
tabulate==0.9.0
tcolorpy==0.1.4
tenacity==8.2.3
terminado==0.18.1
threadpoolctl==3.4.0
thriftpy2==0.4.20
tiktoken==0.6.0
timm==0.9.16
tinycss2==1.3.0
tokenizers==0.15.2
toml==0.10.2
tomli==2.0.1
toolz==0.12.1
torch==2.1.2
torchvision==0.16.2
tornado==6.4
tox==3.28.0
tqdm-multiprocess==0.0.11
tqdm==4.66.2
traitlets==5.14.3
transformers-stream-generator==0.0.5
transformers==4.40.0.dev0
triton==2.1.0
typepy==1.3.2
types-python-dateutil==2.9.0.20240316
typing_extensions==4.11.0
tyro==0.8.3
tzdata==2024.1
uc-micro-py==1.0.3
uri-template==1.3.0
urllib3==2.2.1
uvicorn==0.29.0
virtualenv==20.26.0
wandb==0.16.5
watchdog==4.0.0
wavedrom==2.0.3.post3
wcwidth==0.2.13
webcolors==1.13
webencodings==0.5.1
websocket-client==1.8.0
websockets==12.0
wheel==0.41.2
widgetsnbextension==4.0.10
wrapt==1.16.0
wsproto==1.2.0
xxhash==3.4.1
y-py==0.6.2
yarl==1.9.4
ypy-websocket==0.8.4
zipp==3.18.1
zstandard==0.22.0
import os
import json
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--src", type=str)
parser.add_argument("--dst", type=str)
args = parser.parse_args()
all_answers = []
for line_idx, line in enumerate(open(args.src)):
res = json.loads(line)
question_id = res["question_id"]
text = res["text"].rstrip(".").lower()
all_answers.append({"questionId": question_id, "prediction": text})
with open(args.dst, "w") as f:
json.dump(all_answers, f)
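# Hedged illustration (values are made up): a --src JSONL line such as
#   {"question_id": "201640614", "text": "Yes."}
# ends up in the --dst list as
#   {"questionId": "201640614", "prediction": "yes"}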
import os
import json
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--src", type=str)
parser.add_argument("--dst", type=str)
args = parser.parse_args()
cur_result = {}
for line in open(args.src):
data = json.loads(line)
qid = data["question_id"]
cur_result[f"v1_{qid}"] = data["text"]
with open(args.dst, "w") as f:
json.dump(cur_result, f, indent=2)
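# Hedged illustration (values are made up): a --src JSONL line
#   {"question_id": 123, "text": "two"}
# becomes the --dst mapping entry "v1_123": "two".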
import json
import os
import fire
import re
from convert_sqa_to_llava_base_prompt import build_prompt_chatbot
def convert_to_llava(base_dir, split, prompt_format="QCM-LEA"):
split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[split]
problems = json.load(open(os.path.join(base_dir, "problems.json")))
split_problems = build_prompt_chatbot(problems, split_indices, prompt_format, use_caption=False, is_test=False)
target_format = []
for prob_id, (input, output) in split_problems.items():
if input.startswith("Question: "):
input = input.replace("Question: ", "")
if output.startswith("Answer: "):
output = output.replace("Answer: ", "")
raw_prob_data = problems[prob_id]
if raw_prob_data["image"] is None:
target_format.append(
{
"id": prob_id,
"conversations": [
{"from": "human", "value": f"{input}"},
{"from": "gpt", "value": f"{output}"},
],
}
)
else:
target_format.append(
{
"id": prob_id,
"image": os.path.join(prob_id, raw_prob_data["image"]),
"conversations": [
{"from": "human", "value": f"{input}\n<image>"},
{"from": "gpt", "value": f"{output}"},
],
}
)
print(f"Number of samples: {len(target_format)}")
with open(os.path.join(base_dir, f"llava_{split}_{prompt_format}.json"), "w") as f:
json.dump(target_format, f, indent=2)
def convert_to_jsonl(base_dir, split, prompt_format="QCM-LEPA"):
split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[split]
problems = json.load(open(os.path.join(base_dir, "problems.json")))
split_problems = build_prompt_chatbot(problems, split_indices, prompt_format, use_caption=False, is_test=False)
writer = open(os.path.join(base_dir, f"scienceqa_{split}_{prompt_format}.jsonl"), "w")
for prob_id, (input, output) in split_problems.items():
if input.startswith("Question: "):
input = input.replace("Question: ", "")
if output.startswith("Answer: "):
output = output.replace("Answer: ", "")
raw_prob_data = problems[prob_id]
if raw_prob_data["image"] is None:
data = {
"id": prob_id,
"instruction": f"{input}",
"output": f"{output}",
}
else:
data = {
"id": prob_id,
"image": os.path.join(prob_id, raw_prob_data["image"]),
"instruction": f"{input}\n<image>",
"output": f"{output}",
}
writer.write(json.dumps(data) + "\n")
writer.close()
def main(task, **kwargs):
globals()[task](**kwargs)
if __name__ == "__main__":
fire.Fire(main)
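# Hedged usage examples (assuming this file is saved as convert_sqa_to_llava.py; paths are placeholders):
#   python convert_sqa_to_llava.py convert_to_llava --base_dir /path/to/ScienceQA/data --split train
#   python convert_sqa_to_llava.py convert_to_jsonl --base_dir /path/to/ScienceQA/data --split val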
def get_question_text(problem):
question = problem["question"]
return question
def get_context_text(problem, use_caption):
txt_context = problem["hint"]
img_context = problem["caption"] if use_caption else ""
context = " ".join([txt_context, img_context]).strip()
if context == "":
context = "N/A"
return context
def get_choice_text(problem, options):
choices = problem["choices"]
choice_list = []
for i, c in enumerate(choices):
choice_list.append("({}) {}".format(options[i], c))
choice_txt = " ".join(choice_list)
# print(choice_txt)
return choice_txt
def get_answer(problem, options):
return options[problem["answer"]]
def get_lecture_text(problem):
# \\n: GPT-3 can generate the lecture with more tokens.
lecture = problem["lecture"].replace("\n", "\\n")
return lecture
def get_solution_text(problem):
# \\n: GPT-3 can generate the solution with more tokens
solution = problem["solution"].replace("\n", "\\n")
return solution
def create_one_example_chatbot(format, question, context, choice, answer, lecture, solution, test_example=True):
input_format, output_format = format.split("-")
## Inputs
if input_format == "CQM":
input = f"Context: {context}\nQuestion: {question}\nOptions: {choice}\n"
elif input_format == "QCM":
input = f"Question: {question}\nContext: {context}\nOptions: {choice}\n"
# upper bound experiment
elif input_format == "QCML":
input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture}\n"
elif input_format == "QCME":
input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {solution}\n"
elif input_format == "QCMLE":
input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture} {solution}\n"
elif input_format == "QCLM":
input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture}\nOptions: {choice}\n"
elif input_format == "QCEM":
input = f"Question: {question}\nContext: {context}\nBECAUSE: {solution}\nOptions: {choice}\n"
elif input_format == "QCLEM":
input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture} {solution}\nOptions: {choice}\n"
# Outputs
if test_example:
output = "Answer:"
elif output_format == "A":
output = f"Answer: The answer is {answer}."
elif output_format == "AL":
output = f"Answer: The answer is {answer}. BECAUSE: {solution}"
elif output_format == "AE":
output = f"Answer: The answer is {answer}. BECAUSE: {lecture}"
elif output_format == "ALE":
output = f"Answer: The answer is {answer}. BECAUSE: {lecture} {solution}"
elif output_format == "AEL":
output = f"Answer: The answer is {answer}. BECAUSE: {solution} {lecture}"
elif output_format == "LA":
output = f"Answer: {lecture} The answer is {answer}."
elif output_format == "EA":
output = f"Answer: {solution} The answer is {answer}."
elif output_format == "LEA":
output = f"Answer: {lecture} {solution} The answer is {answer}."
elif output_format == "ELA":
output = f"Answer: {solution} {lecture} The answer is {answer}."
elif output_format == "LEPA":
output = ""
if len(lecture.strip()) > 0:
output += f"LECTURE: {lecture}\n"
if len(solution.strip()) > 0:
output += f"SOLUTION: {solution}\n"
output += "###\n"
output += f"ANSWER: {answer}."
input = input.replace("  ", " ").strip()
output = output.replace("  ", " ").strip()
if input.endswith("BECAUSE:"):
input = input.replace("BECAUSE:", "").strip()
if output.endswith("BECAUSE:"):
output = output.replace("BECAUSE:", "").strip()
return input, output
def create_one_example(format, question, context, choice, answer, lecture, solution, test_example=True):
input_format, output_format = format.split("-")
## Inputs
if input_format == "CQM":
input = f"Context: {context}\nQuestion: {question}\nOptions: {choice}\n"
elif input_format == "QCM":
input = f"Question: {question}\nContext: {context}\nOptions: {choice}\n"
# upper bound experiment
elif input_format == "QCML":
input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture}\n"
elif input_format == "QCME":
input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {solution}\n"
elif input_format == "QCMLE":
input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture} {solution}\n"
elif input_format == "QCLM":
input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture}\nOptions: {choice}\n"
elif input_format == "QCEM":
input = f"Question: {question}\nContext: {context}\nBECAUSE: {solution}\nOptions: {choice}\n"
elif input_format == "QCLEM":
input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture} {solution}\nOptions: {choice}\n"
# Outputs
if test_example:
output = "Answer:"
elif output_format == "A":
output = f"Answer: The answer is {answer}."
elif output_format == "AL":
output = f"Answer: The answer is {answer}. BECAUSE: {solution}"
elif output_format == "AE":
output = f"Answer: The answer is {answer}. BECAUSE: {lecture}"
elif output_format == "ALE":
output = f"Answer: The answer is {answer}. BECAUSE: {lecture} {solution}"
elif output_format == "AEL":
output = f"Answer: The answer is {answer}. BECAUSE: {solution} {lecture}"
elif output_format == "LA":
output = f"Answer: {lecture} The answer is {answer}."
elif output_format == "EA":
output = f"Answer: {solution} The answer is {answer}."
elif output_format == "LEA":
output = f"Answer: {lecture} {solution} The answer is {answer}."
elif output_format == "ELA":
output = f"Answer: {solution} {lecture} The answer is {answer}."
text = input + output
text = text.replace("  ", " ").strip()
if text.endswith("BECAUSE:"):
text = text.replace("BECAUSE:", "").strip()
return text
def create_one_example_gpt4(format, question, context, choice, answer, lecture, solution, test_example=True):
input_format, output_format = format.split("-")
## Inputs
if input_format == "CQM":
input = f"Context: {context}\nQuestion: {question}\nOptions: {choice}\n"
elif input_format == "QCM":
input = f"Question: {question}\nContext: {context}\nOptions: {choice}\n"
# upper bound experiment
elif input_format == "QCML":
input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture}\n"
elif input_format == "QCME":
input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {solution}\n"
elif input_format == "QCMLE":
input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture} {solution}\n"
elif input_format == "QCLM":
input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture}\nOptions: {choice}\n"
elif input_format == "QCEM":
input = f"Question: {question}\nContext: {context}\nBECAUSE: {solution}\nOptions: {choice}\n"
elif input_format == "QCLEM":
input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture} {solution}\nOptions: {choice}\n"
# Outputs
if test_example:
output = "Answer:"
elif output_format == "A":
output = f"Answer: The answer is {answer}."
elif output_format == "AL":
output = f"Answer: The answer is {answer}. BECAUSE: {solution}"
elif output_format == "AE":
output = f"Answer: The answer is {answer}. BECAUSE: {lecture}"
elif output_format == "ALE":
output = f"Answer: The answer is {answer}. BECAUSE: {lecture} {solution}"
elif output_format == "AEL":
output = f"Answer: The answer is {answer}. BECAUSE: {solution} {lecture}"
elif output_format == "LA":
output = f"Answer: {lecture} The answer is {answer}."
elif output_format == "EA":
output = f"Answer: {solution} The answer is {answer}."
elif output_format == "LEA":
output = f"Answer: {lecture} {solution} The answer is {answer}."
elif output_format == "ELA":
output = f"Answer: {solution} {lecture} The answer is {answer}."
input = input.replace("  ", " ").strip()
output = output.replace("  ", " ").strip()
if output.endswith("BECAUSE:"):
output = output.replace("BECAUSE:", "").strip()
user_prompt = {"role": "user", "content": f"Can you explain {input}?"}
assistant_prompt = {"role": "assistant", "content": f"{output}"}
return user_prompt, assistant_prompt
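# Hedged refactor sketch (not part of the original file): the three create_one_example*
# functions above repeat the same input-format branching. A shared helper such as the
# hypothetical build_input() below could replace those branches without changing behavior.
def build_input(input_format, question, context, choice, lecture, solution):
    templates = {
        "CQM": f"Context: {context}\nQuestion: {question}\nOptions: {choice}\n",
        "QCM": f"Question: {question}\nContext: {context}\nOptions: {choice}\n",
        "QCML": f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture}\n",
        "QCME": f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {solution}\n",
        "QCMLE": f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture} {solution}\n",
        "QCLM": f"Question: {question}\nContext: {context}\nBECAUSE: {lecture}\nOptions: {choice}\n",
        "QCEM": f"Question: {question}\nContext: {context}\nBECAUSE: {solution}\nOptions: {choice}\n",
        "QCLEM": f"Question: {question}\nContext: {context}\nBECAUSE: {lecture} {solution}\nOptions: {choice}\n",
    }
    return templates[input_format]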
def build_prompt_chatbot(problems, shot_qids, prompt_format, use_caption=False, options=["A", "B", "C", "D", "E"], is_test=False):
examples = {}
for qid in shot_qids:
question = get_question_text(problems[qid])
context = get_context_text(problems[qid], use_caption)
choice = get_choice_text(problems[qid], options)
answer = get_answer(problems[qid], options)
lecture = get_lecture_text(problems[qid]).replace("\\n", "\n")
solution = get_solution_text(problems[qid]).replace("\\n", "\n")
train_example = create_one_example_chatbot(prompt_format, question, context, choice, answer, lecture, solution, test_example=is_test)
examples[qid] = train_example
return examples
def build_prompt(problems, shot_qids, test_qid, args):
examples = []
# n-shot training examples
for qid in shot_qids:
question = get_question_text(problems[qid])
context = get_context_text(problems[qid], args.use_caption)
choice = get_choice_text(problems[qid], args.options)
answer = get_answer(problems[qid], args.options)
lecture = get_lecture_text(problems[qid])
solution = get_solution_text(problems[qid])
train_example = create_one_example(args.prompt_format, question, context, choice, answer, lecture, solution, test_example=False)
examples.append(train_example)
# test example
question = get_question_text(problems[test_qid])
context = get_context_text(problems[test_qid], args.use_caption)
choice = get_choice_text(problems[test_qid], args.options)
answer = get_answer(problems[test_qid], args.options)
lecture = get_lecture_text(problems[test_qid])
solution = get_solution_text(problems[test_qid])
test_example = create_one_example(args.prompt_format, question, context, choice, answer, lecture, solution, test_example=True)
examples.append(test_example)
# create the prompt input
prompt_input = "\n\n".join(examples)
return prompt_input
def build_prompt_gpt4(problems, shot_qids, test_qid, args):
prompt_array = [{"role": "system", "content": "You are a helpful assistant."}]
# n-shot training examples
for qid in shot_qids:
question = get_question_text(problems[qid])
context = get_context_text(problems[qid], args.use_caption)
choice = get_choice_text(problems[qid], args.options)
answer = get_answer(problems[qid], args.options)
lecture = get_lecture_text(problems[qid])
solution = get_solution_text(problems[qid])
user_prompt, assistant_prompt = create_one_example_gpt4(args.prompt_format, question, context, choice, answer, lecture, solution, test_example=False)
prompt_array.append(user_prompt)
prompt_array.append(assistant_prompt)
# test example
question = get_question_text(problems[test_qid])
context = get_context_text(problems[test_qid], args.use_caption)
choice = get_choice_text(problems[test_qid], args.options)
answer = get_answer(problems[test_qid], args.options)
lecture = get_lecture_text(problems[test_qid])
solution = get_solution_text(problems[test_qid])
user_prompt, assistant_prompt = create_one_example_gpt4(args.prompt_format, question, context, choice, answer, lecture, solution, test_example=True)
prompt_array.append(user_prompt)
prompt_array.append(assistant_prompt)
return prompt_array
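# Hedged usage sketch (values are placeholders): build_prompt and build_prompt_gpt4 expect
# an args object exposing use_caption, options, and prompt_format, e.g.
#   from types import SimpleNamespace
#   args = SimpleNamespace(use_caption=False, options=["A", "B", "C", "D", "E"], prompt_format="QCM-A")
#   prompt_input = build_prompt(problems, shot_qids, test_qid, args)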
import os
import argparse
import json
from llava.eval.m4c_evaluator import EvalAIAnswerProcessor
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--annotation-file", type=str, required=True)
parser.add_argument("--result-file", type=str, required=True)
parser.add_argument("--result-upload-file", type=str, required=True)
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
os.makedirs(os.path.dirname(args.result_upload_file), exist_ok=True)
results = []
error_line = 0
for line_idx, line in enumerate(open(args.result_file)):
try:
results.append(json.loads(line))
except json.JSONDecodeError:
error_line += 1
results = {x["question_id"]: x["text"] for x in results}
test_split = [json.loads(line) for line in open(args.annotation_file)]
split_ids = set([x["question_id"] for x in test_split])
print(f"total results: {len(results)}, total split: {len(test_split)}, error_line: {error_line}")
all_answers = []
answer_processor = EvalAIAnswerProcessor()
for x in test_split:
# import pdb; pdb.set_trace()
assert x["question_id"] in results, f"missing from results: {x}"
all_answers.append({"image": x["image"], "answer": answer_processor(results[x["question_id"]])})
with open(args.result_upload_file, "w") as f:
json.dump(all_answers, f)
import os
import argparse
import json
from llava.eval.m4c_evaluator import EvalAIAnswerProcessor
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--dir", type=str, default="./playground/data/eval/vqav2")
parser.add_argument("--ckpt", type=str, required=True)
parser.add_argument("--split", type=str, required=True)
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
src = os.path.join(args.dir, "answers", args.split, args.ckpt, "merge.jsonl")
test_split = os.path.join(args.dir, "llava_vqav2_mscoco_test2015.jsonl")
dst = os.path.join(args.dir, "answers_upload", args.split, f"{args.ckpt}.json")
os.makedirs(os.path.dirname(dst), exist_ok=True)
results = []
error_line = 0
for line_idx, line in enumerate(open(src)):
try:
results.append(json.loads(line))
except json.JSONDecodeError:
error_line += 1
results = {x["question_id"]: x["text"] for x in results}
test_split = [json.loads(line) for line in open(test_split)]
split_ids = set([x["question_id"] for x in test_split])
print(f"total results: {len(results)}, total split: {len(test_split)}, error_line: {error_line}")
all_answers = []
answer_processor = EvalAIAnswerProcessor()
for x in test_split:
if x["question_id"] not in results:
all_answers.append({"question_id": x["question_id"], "answer": ""})
else:
all_answers.append({"question_id": x["question_id"], "answer": answer_processor(results[x["question_id"]])})
with open(dst, "w") as f:
json.dump(all_answers, f)
import json
import os
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
def load_data(json_path):
with open(json_path, "r") as f:
return json.load(f)
def filter_data(data):
# filtered_data = [item for item in data if "image" in item and "text" in item["image"]]
filtered_data = [item for item in data if "image" in item]
return filtered_data
from multiprocessing import Pool
import functools
def calculate_image_dimension(item, images_folder):
image_path = os.path.join(images_folder, item["image"])
try:
with Image.open(image_path) as img:
width, height = img.size
return width, height
except Exception as e:
print(f"Error opening {image_path}: {e}")
return None, None
def calculate_image_dimensions_multiprocess(filtered_data, images_folder, num_processes=256):
with Pool(num_processes) as p:
dimensions = list(tqdm(p.imap(functools.partial(calculate_image_dimension, images_folder=images_folder), filtered_data), total=len(filtered_data), desc="Calculating image dimensions"))
widths, heights = zip(*[dim for dim in dimensions if dim[0] is not None])
return list(widths), list(heights)
def tokenize(text):
# whitespace split: approximates length in words, not in model tokens
return text.split()
def calculate_tokenized_lengths(data):
lengths = []
for item in tqdm(data, desc="Tokenizing conversations"):
for conversation in item["conversations"]:
tokenized_value = tokenize(conversation["value"])
lengths.append(len(tokenized_value))
return lengths
import argparse
def main():
parser = argparse.ArgumentParser(description="Process data for LLaVA_Next project.")
parser.add_argument("--json_path", type=str, help="Path to the JSON file containing data.")
parser.add_argument("--images_folder", type=str, default="/mnt/bn/vl-research/data/llava_data", help="Path to the folder containing images.")
args = parser.parse_args()
json_path = args.json_path
llava_instruct_name = os.path.basename(json_path).replace(".json", "")
images_folder = args.images_folder
data = load_data(json_path)
filtered_data = filter_data(data)
if len(filtered_data) != 0:
print(f"Total data items: {len(data)}, Filtered data items: {len(filtered_data)}")
widths, heights = calculate_image_dimensions_multiprocess(filtered_data, images_folder)
max_width = max(widths)
max_height = max(heights)
print(f"Max width: {max_width}, Max height: {max_height}")
tokenized_lengths = calculate_tokenized_lengths(data)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(24, 12))
if len(filtered_data) != 0:
# Plot 2D histogram
if min(widths) == max(widths):
widths_bins = [min(widths), max(widths) + 1]
else:
widths_bins = np.arange(min(widths), max(widths) + 100, 100)
if min(heights) == max(heights):
heights_bins = [min(heights), max(heights) + 1]
else:
heights_bins = np.arange(min(heights), max(heights) + 100, 100)
h, xedges, yedges, image = ax1.hist2d(widths, heights, bins=[widths_bins, heights_bins], cmap=plt.cm.jet, density=True)
fig.colorbar(image, ax=ax1)
ax1.set_xlabel("Width")
ax1.set_ylabel("Height")
ax1.set_title(f"dist_{llava_instruct_name}_2d_w_h\nMax width: {max(widths)}, Max height: {max(heights)}", fontsize=10)
# Plot histogram
hist, bin_edges = np.histogram(tokenized_lengths, bins=np.arange(0, max(tokenized_lengths) + 10, 100))
bins = np.arange(0, max(tokenized_lengths) + 10, 100)
ax2.bar(bin_edges[:-1], hist, width=7, edgecolor="black", log=True)
# Display every nth label on the x-axis
n = 8 # Adjust this value to control the number of labels displayed
ticks = bins[::n]
tick_labels = [int(tick) for tick in ticks]
ax2.set_xticks(ticks)
ax2.set_xticklabels(tick_labels, rotation=90, fontsize=8)
ax2.set_xlim(min(bin_edges), max(bin_edges))
ax2.set_xlabel("Tokenized Length")
ax2.set_ylabel("Count (log scale)")
ax2.set_title(f"dist_{llava_instruct_name}_tokenized_length", fontsize=8)
plt.tight_layout()
plt.savefig(f"/mnt/bn/vl-research/workspace/boli01/projects/LLaVA_Next/notebooks/sft_data/dist_{llava_instruct_name}_combined.png")
print(f"Plots saved to /mnt/bn/vl-research/workspace/boli01/projects/LLaVA_Next/notebooks/sft_data/dist_{llava_instruct_name}_combined.png")
if __name__ == "__main__":
main()
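# Hedged usage example (the script filename is a placeholder; --images_folder falls back to the default above):
#   python plot_sft_data_stats.py --json_path /path/to/llava_instruct.json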
import json
import re
json_path = "/mnt/bn/vl-research/workspace/boli01/projects/sft_data_workspace/vlfeedback_80k.jsonl"
with open(json_path, "r") as f:
data = f.readlines()
data = [json.loads(d) for d in data]
def convert_format(original_data, dimension="Visual Faithfulness"):
converted_data = []
for item in original_data:
# Assuming the best response is the one with the highest helpfulness rating
best_completion = max(item["completions"], key=lambda x: int(x["annotations"]["Helpfulness"]["Rating"]))
best_response = best_completion["response"]
best_model = best_completion["model"]
if "†source" in best_response:
print(best_response)
# Regex pattern to match the pattern 【digit†source】
pattern = r"【\d+†source】"
# Replace the matched patterns with an empty string
cleaned_text = re.sub(pattern, "", best_response)
best_response = cleaned_text
print(f"*****************************************")
print(best_response)
# Assuming the worst response is the one with the lowest helpfulness rating
worst_completion = min(item["completions"], key=lambda x: int(x["annotations"]["Helpfulness"]["Rating"]))
worst_response = worst_completion["response"]
if "†source" in worst_response:
print(worst_response)
# Regex pattern to match the pattern 【digit†source】
pattern = r"【\d+†source】"
# Replace the matched patterns with an empty string
cleaned_text = re.sub(pattern, "", worst_response)
worst_response = cleaned_text
print(f"*****************************************")
print(worst_response)
# Extract scores
best_score = int(best_completion["annotations"][dimension]["Rating"])
worst_score = int(worst_completion["annotations"][dimension]["Rating"])
# Construct the new format
new_item = {
"id": item["id"],
"prompt": item["prompt"],
"answer": "",
"image": f"silkie_dpo/{item['id']}.jpg", # Assuming the video ID is the last part of the original ID
"chosen": best_response,
"rejected": worst_response,
"chosen_score": best_score,
"rejected_score": worst_score,
}
converted_data.append(new_item)
return converted_data
for dimension in ["Visual Faithfulness", "Helpfulness", "Ethical Considerations"]:
converted_data = convert_format(data, dimension=dimension)
with open(f"/mnt/bn/vl-research/data/llava_instruct/dpo_data/silkie_dpo_data_{dimension.replace(' ', '_').lower()}_{len(converted_data)}.json", "w") as f:
json.dump(converted_data, f, indent=4)
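# Hedged refactor sketch (not part of the original script): the "†source" cleanup above is
# duplicated for the best and worst responses; a small helper like the hypothetical
# strip_source_tags() below could be used for both.
def strip_source_tags(text):
    """Remove citation markers of the form 【digit†source】 from a response."""
    return re.sub(r"【\d+†source】", "", text)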
python3 -m pip install --upgrade pip;
export http_proxy=http://sys-proxy-rd-relay.byted.org:8118;
export https_proxy=http://sys-proxy-rd-relay.byted.org:8118;
export HF_HOME=/mnt/bn/vl-research-boli01-cn/.cache/huggingface;
export HF_TOKEN="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";
export HF_HUB_ENABLE_HF_TRANSFER="1";
cd /mnt/bn/vl-research-boli01-cn/projects/zzz/lmms-eval;
pip install -e .;
cd /mnt/bn/vl-research-boli01-cn/projects/zzz/LLaVA_Next;
pip install -e .;
python3 -m pip install ninja;
python3 -m pip install flash-attn --no-build-isolation;
bash /mnt/bn/vl-research-boli01-cn/projects/zzz/LLaVA_Next/cn_scripts/vicuna/internal0.6m_finetune_llava1.6mix_7b_v0.2_unfreeze.sh
accelerate launch --num_processes 8 --main_process_port 12345 -m lmms_eval \
--model llava \
--model_args pretrained="/mnt/bn/vl-research-boli01-cn/projects/zzz/LLaVA_Next/internal_project_checkpoints/llavanext-lmsys_vicuna-7b-v1.5-clip-vit-large-patch14-336-mlp2x_gelu-pretrain_internal0.6m_vicuna_v1_finetune_llava1.6_datamix_unfreezeVIS_1e" \
--tasks ok_vqa,textcaps_val,mme_test,mmmu,cmmmu,coco2017_cap_val,vizwiz_vqa_val,ai2d,chartqa,pope \
--batch_size 1 \
--log_samples \
--log_samples_suffix debug \
--output_path ./logs/ \
--wandb_args 'project=llava-next-lmms-eval,job_type=eval';
#!/bin/bash
cd /mnt/bn/vl-research/workspace/boli01/zzzprojects/LLaVA
# Install yolk3k if not installed
if ! pip show yolk3k > /dev/null 2>&1; then
pip install yolk3k
fi
# Get the installed version of transformers
installed_version=$(pip show transformers | grep Version | cut -d ' ' -f 2)
# Get the latest version of transformers from PyPI
latest_version=$(yolk -V transformers | cut -d ' ' -f 2)
# Check if the installed version is not the latest
if [ "$installed_version" != "$latest_version" ]; then
pip install -U transformers
fi
# Get the installed version of deepspeed
installed_version=$(pip show deepspeed | grep Version | cut -d ' ' -f 2)
# Get the latest version of deepspeed from PyPI
latest_version=$(yolk -V deepspeed | cut -d ' ' -f 2)
# Check if the installed version is not the latest
# pip install deepspeed==0.12.2
if [ "$installed_version" != "$latest_version" ]; then
pip install deepspeed==0.12.2
fi
# Install flash-attn if not installed
if ! pip show flash-attn > /dev/null 2>&1; then
pip install flash-attn --no-build-isolation
fi
################## VICUNA ##################
PROMPT_VERSION=v1
MODEL_VERSION="vicuna-7b-v1-5"
################## VICUNA ##################
################## project ##################
PROJECT_NAME="ds_llava-vicuna-7b-v1-5-mlp2x_gelu-pretrain_blip558k_plain"
################## data ##################
DATA_NAME="mixtral_instruct_158K_V1"
# wandb configure
export WANDB_API_KEY="03fc62d68025c9498cf6493432551badd7d4f953"
wandb login $WANDB_API_KEY
export WANDB_NAME=$PROJECT_NAME--$MODEL_VERSION--$DATA_NAME
export WANDB_PROJECT=LLaVA_Mixtral
export WANDB_MODE=online
# wandb online
deepspeed --master_port 26000 \
llava/train/train_mem.py \
--deepspeed ./scripts/zero2.json \
--model_name_or_path ./checkpoints/$MODEL_VERSION \
--version $PROMPT_VERSION \
--data_path ./playground/data/$DATA_NAME.json \
--image_folder /mnt/bn/vl-research/workspace/boli01/data/playground/data/coco/train2017 \
--vision_tower openai/clip-vit-large-patch14 \
--pretrain_mm_mlp_adapter ./checkpoints/$PROJECT_NAME/mm_projector.bin \
--mm_vision_select_layer -2 \
--mm_projector_type mlp2x_gelu \
--mm_use_im_start_end False \
--mm_use_im_patch_token False \
--bf16 True \
--output_dir ./checkpoints/llava--$PROJECT_NAME--$MODEL_VERSION--$DATA_NAME--finetune \
--num_train_epochs 1 \
--per_device_train_batch_size 16 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 1 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 50000 \
--save_total_limit 1 \
--learning_rate 2e-5 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--tf32 True \
--model_max_length 2048 \
--gradient_checkpointing True \
--dataloader_num_workers 16 \
--lazy_preprocess True \
--report_to wandb
#!/bin/bash
dataset_name=$1
# Uncomment and set the following variables correspondingly to run this script:
cd /mnt/bn/vl-research/workspace/boli01/zzzprojects/LLaVA
# Install yolk3k if not installed
if ! pip show yolk3k > /dev/null 2>&1; then
pip install yolk3k
fi
# Get the installed version of transformers
installed_version=$(pip show transformers | grep Version | cut -d ' ' -f 2)
# Get the latest version of transformers from PyPI
latest_version=$(yolk -V transformers | cut -d ' ' -f 2)
# Check if the installed version is not the latest
if [ "$installed_version" != "$latest_version" ]; then
pip install -U transformers
fi
# Get the installed version of deepspeed
installed_version=$(pip show deepspeed | grep Version | cut -d ' ' -f 2)
# Get the latest version of deepspeed from PyPI
latest_version=$(yolk -V deepspeed | cut -d ' ' -f 2)
# Check if the installed version is not the latest
if [ "$installed_version" != "$latest_version" ]; then
pip install deepspeed==0.12.2
fi
# Install flash-attn if not installed
if ! pip show flash-attn > /dev/null 2>&1; then
pip install flash-attn --no-build-isolation
fi
################## VICUNA ##################
PROMPT_VERSION=v1
MODEL_VERSION="vicuna-7b-v1-5"
################## VICUNA ##################
################## project ##################
PROJECT_NAME="ds_llava-vicuna-7b-v1-5-mlp2x_gelu-pretrain_blip558k_plain"
################## data ##################
DATA_NAME=$dataset_name
# wandb configure
export WANDB_API_KEY="03fc62d68025c9498cf6493432551badd7d4f953"
wandb login $WANDB_API_KEY
export WANDB_NAME=$PROJECT_NAME--$MODEL_VERSION--$DATA_NAME
export WANDB_PROJECT=LLaVA_Mixtral
export WANDB_MODE=online
wandb online
deepspeed --master_port 26000 \
llava/train/train_mem.py \
--deepspeed ./scripts/zero2.json \
--model_name_or_path ./checkpoints/$MODEL_VERSION \
--version $PROMPT_VERSION \
--data_path ./playground/data/$DATA_NAME.json \
--image_folder /mnt/bn/vl-research/workspace/boli01/data/playground/data \
--vision_tower openai/clip-vit-large-patch14 \
--pretrain_mm_mlp_adapter ./checkpoints/$PROJECT_NAME/mm_projector.bin \
--mm_vision_select_layer -2 \
--mm_projector_type mlp2x_gelu \
--mm_use_im_start_end False \
--mm_use_im_patch_token False \
--bf16 True \
--output_dir ./checkpoints/llava--$PROJECT_NAME--$MODEL_VERSION--$DATA_NAME--finetune \
--num_train_epochs 1 \
--per_device_train_batch_size 16 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 1 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 50000 \
--save_total_limit 1 \
--learning_rate 2e-5 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--tf32 True \
--model_max_length 2048 \
--gradient_checkpointing True \
--dataloader_num_workers 16 \
--lazy_preprocess True \
--report_to wandb
#!/bin/bash
# Uncomment and set the following variables correspondingly to run this script:
################## VICUNA ##################
# PROMPT_VERSION=v1
# MODEL_VERSION="vicuna-v1-3-7b"
################## VICUNA ##################
################## LLaMA-2 ##################
# PROMPT_VERSION="llava_llama_2"
# MODEL_VERSION="llama-2-7b-chat"
################## LLaMA-2 ##################
deepspeed llava/train/train_mem.py \
--deepspeed ./scripts/zero2.json \
--model_name_or_path ./checkpoints/$MODEL_VERSION \
--version $PROMPT_VERSION \
--data_path ./playground/data/llava_instruct_158k.json \
--image_folder /path/to/coco/train2017 \
--vision_tower openai/clip-vit-large-patch14 \
--pretrain_mm_mlp_adapter ./checkpoints/llava-$MODEL_VERSION-pretrain/mm_projector.bin \
--mm_vision_select_layer -2 \
--mm_use_im_start_end False \
--mm_use_im_patch_token False \
--bf16 True \
--output_dir ./checkpoints/llava-$MODEL_VERSION-finetune \
--num_train_epochs 3 \
--per_device_train_batch_size 16 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 1 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 50000 \
--save_total_limit 1 \
--learning_rate 2e-5 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--tf32 True \
--model_max_length 2048 \
--gradient_checkpointing True \
--dataloader_num_workers 16 \
--lazy_preprocess True \
--report_to wandb
#!/bin/bash
# Uncomment and set the following variables correspondingly to run this script:
################## VICUNA ##################
# PROMPT_VERSION=v1
# MODEL_VERSION="vicuna-v1-3-7b"
################## VICUNA ##################
################## LLaMA-2 ##################
# PROMPT_VERSION="llava_llama_2"
# MODEL_VERSION="llama-2-7b-chat"
################## LLaMA-2 ##################
deepspeed llava/train/train_mem.py \
--deepspeed ./scripts/zero2.json \
--lora_enable True \
--model_name_or_path ./checkpoints/$MODEL_VERSION \
--version $PROMPT_VERSION \
--data_path ./playground/data/llava_instruct_80k.json \
--image_folder /path/to/coco/train2017 \
--vision_tower openai/clip-vit-large-patch14 \
--pretrain_mm_mlp_adapter ./checkpoints/llava-$MODEL_VERSION-pretrain/mm_projector.bin \
--mm_vision_select_layer -2 \
--mm_use_im_start_end False \
--mm_use_im_patch_token False \
--bf16 True \
--output_dir ./checkpoints/llava-$MODEL_VERSION-finetune_lora \
--num_train_epochs 1 \
--per_device_train_batch_size 16 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 1 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 50000 \
--save_total_limit 1 \
--learning_rate 2e-5 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--tf32 True \
--model_max_length 2048 \
--gradient_checkpointing True \
--lazy_preprocess True \
--dataloader_num_workers 16 \
--report_to wandb
#!/bin/bash
cd /mnt/bn/vl-research/workspace/boli01/zzzprojects/LLaVA
# Install yolk3k if not installed
if ! pip show yolk3k > /dev/null 2>&1; then
pip install yolk3k
fi
# Get the installed version of transformers
installed_version=$(pip show transformers | grep Version | cut -d ' ' -f 2)
# Get the latest version of transformers from PyPI
latest_version=$(yolk -V transformers | cut -d ' ' -f 2)
# Check if the installed version is not the latest
if [ "$installed_version" != "$latest_version" ]; then
pip install -U transformers
fi
# Get the installed version of deepspeed
installed_version=$(pip show deepspeed | grep Version | cut -d ' ' -f 2)
# Get the latest version of deepspeed from PyPI
latest_version=$(yolk -V deepspeed | cut -d ' ' -f 2)
# Check if the installed version is not the latest
if [ "$installed_version" != "$latest_version" ]; then
pip install deepspeed==0.12.2
fi
# Install flash-attn if not installed
if ! pip show flash-attn > /dev/null 2>&1; then
pip install flash-attn --no-build-isolation
fi
################## MISTRAL ##################
PROMPT_VERSION=mistral_instruct
MODEL_VERSION="Mistral-7B-Instruct-v0.2"
################## MISTRAL ##################
################## project ##################
PROJECT_NAME="ds_llava-Mistral-7B-Instruct-v0.2-mlp2x_gelu-pretrain_blip558k_plain"
################## data ##################
DATA_NAME="mixtral_instruct_158K_V1"
# wandb configure
export WANDB_API_KEY="03fc62d68025c9498cf6493432551badd7d4f953"
wandb login $WANDB_API_KEY
export WANDB_NAME=$PROJECT_NAME--$MODEL_VERSION--$DATA_NAME
export WANDB_PROJECT=LLaVA_Mixtral
export WANDB_MODE=online
wandb online
deepspeed --master_port 26000 \
llava/train/train_mem.py \
--deepspeed ./scripts/zero2.json \
--model_name_or_path ./checkpoints/$MODEL_VERSION \
--version $PROMPT_VERSION \
--data_path ./playground/data/$DATA_NAME.json \
--image_folder /mnt/bn/vl-research/workspace/boli01/data/playground/data/coco/train2017 \
--vision_tower openai/clip-vit-large-patch14 \
--pretrain_mm_mlp_adapter ./checkpoints/$PROJECT_NAME/mm_projector.bin \
--mm_vision_select_layer -2 \
--mm_projector_type mlp2x_gelu \
--mm_use_im_start_end False \
--mm_use_im_patch_token False \
--bf16 True \
--output_dir ./checkpoints/llava--$PROJECT_NAME--$MODEL_VERSION--$DATA_NAME--finetune \
--num_train_epochs 1 \
--per_device_train_batch_size 16 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 1 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 50000 \
--save_total_limit 1 \
--learning_rate 2e-5 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--tf32 True \
--model_max_length 2048 \
--gradient_checkpointing True \
--dataloader_num_workers 16 \
--lazy_preprocess True \
--report_to wandb