\")]\n",
"\n",
"def collate_fn(examples):\n",
" # Build a padded training batch from raw dataset examples.\n",
" # NOTE(review): relies on `processor` and `image_token_id` defined in earlier cells.\n",
" texts = []\n",
" images = []\n",
" for example in examples:\n",
" image = example[\"image\"]\n",
" # The processor expects 3-channel images; convert grayscale/RGBA inputs.\n",
" if image.mode != 'RGB':\n",
" image = image.convert('RGB')\n",
" question = example[\"question\"]\n",
" answer = example[\"multiple_choice_answer\"]\n",
" # Chat-format conversation: the user turn carries the instruction, the image\n",
" # placeholder, and the question; the assistant turn carries the target answer.\n",
" messages = [\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": [\n",
" {\"type\": \"text\", \"text\": \"Answer briefly.\"},\n",
" {\"type\": \"image\"},\n",
" {\"type\": \"text\", \"text\": question}\n",
" ]\n",
" },\n",
" {\n",
" \"role\": \"assistant\",\n",
" \"content\": [\n",
" {\"type\": \"text\", \"text\": answer}\n",
" ]\n",
" }\n",
" ]\n",
" # Render the conversation with the model's chat template; no generation prompt\n",
" # is appended because the assistant answer is already part of the sequence.\n",
" text = processor.apply_chat_template(messages, add_generation_prompt=False)\n",
" texts.append(text.strip())\n",
" images.append([image])\n",
"\n",
" batch = processor(text=texts, images=images, return_tensors=\"pt\", padding=True)\n",
" # Labels are the input ids with padding and image tokens replaced by -100,\n",
" # the ignore index for the cross-entropy loss, so no loss is computed on them.\n",
" labels = batch[\"input_ids\"].clone()\n",
" labels[labels == processor.tokenizer.pad_token_id] = -100\n",
" labels[labels == image_token_id] = -100\n",
" batch[\"labels\"] = labels\n",
"\n",
" return batch"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kEYDjWpE3LD5"
},
"source": [
"## Training"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QvAs896cdwg8"
},
"source": [
"We can now initialize `TrainingArguments` and pass them to `Trainer`.\n",
"\n",
"Some notes:\n",
"- If you use 8-bit QLoRA with the below setup it uses around 16.4 GB VRAM (beautiful, fits comfortably inside L4, Colab free tier)\n",
"- We use gradient accumulation to simulate a larger batch size.\n",
"- We also save memory on intermediate activations by using gradient checkpointing.\n",
"\n",
"**Disclaimer:**\n",
"The techniques here aren't a free lunch. The latter two add extra compute to training and thus slow it down a bit (for reference, on two A100s with a batch size of 16, training took 2 hrs 43 mins with gradient accumulation steps of 4; disabling it reduced the time to 2 hrs 35 mins).\n",
"If you want to speed things up, you might play around: reduce to 4-bit precision and use a higher batch size. Note that 4-bit might result in the model learning less."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "QNE2yWAYrAhD"
},
"outputs": [],
"source": [
"from transformers import TrainingArguments, Trainer\n",
"\n",
"# Short model name (the part after the last \"/\") used for local and Hub paths.\n",
"model_name = model_id.split(\"/\")[-1]\n",
"\n",
"training_args = TrainingArguments(\n",
" num_train_epochs=1,\n",
" per_device_train_batch_size=4,\n",
" gradient_accumulation_steps=4, # simulates a larger effective batch size\n",
" warmup_steps=50,\n",
" learning_rate=1e-4,\n",
" weight_decay=0.01,\n",
" logging_steps=25,\n",
" save_strategy=\"steps\",\n",
" save_steps=250,\n",
" save_total_limit=1, # keep only the most recent checkpoint to save disk\n",
" optim=\"paged_adamw_8bit\", # for 8-bit, keep this, else adamw_hf\n",
" bf16=True, # underlying precision for 8bit\n",
" output_dir=f\"./{model_name}-vqav2\",\n",
" hub_model_id=f\"{model_name}-vqav2\",\n",
" report_to=\"tensorboard\",\n",
" remove_unused_columns=False, # keep raw columns so collate_fn can read them\n",
" gradient_checkpointing=True # trade extra compute for lower activation memory\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "oBBSDpBhreJd",
"outputId": "071ed677-1d9f-4f98-9d19-64834440c9c4"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n"
]
}
],
"source": [
"trainer = Trainer(\n",
" model=model,\n",
" args=training_args,\n",
" # Custom collator (defined above) that builds padded multimodal batches.\n",
" data_collator=collate_fn,\n",
" train_dataset=train_ds,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_QOCpw_-uYYo",
"outputId": "7abb6937-c072-435a-c3f5-6dbb5b0b9eea"
},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
"
\n",
" [ 9/670 01:41 < 2:39:41, 0.07 it/s, Epoch 0.01/1]\n",
"
\n",
" \n",
" \n",
" \n",
" | Step | \n",
" Training Loss | \n",
"
\n",
" \n",
" \n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Run fine-tuning; logs and checkpoints go to output_dir set above.\n",
"trainer.train()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0hN0QD9_uYYo"
},
"outputs": [],
"source": [
"# Upload the trained model to the Hub repo named by hub_model_id above.\n",
"trainer.push_to_hub()"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "A100",
"provenance": [],
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.4"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {}
}
},
"nbformat": 4,
"nbformat_minor": 0
}