Unverified commit c77c1e05, authored by Chayenne, committed by GitHub

fix black in pre-commit (#1940)

parent dca87ec3
@@ -30,6 +30,6 @@ repos:
     rev: 24.10.0
     hooks:
       - id: black
-        additional_dependencies: ['.[jupyter]']
-        types: [python, jupyter]
-        types_or: [python, jupyter]
+        types: [python]
+      - id: black-jupyter
+        types: [jupyter]
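After this split, "pre-commit run --all-files" should run the black hook on Python sources and the black-jupyter hook on notebook files; the hunks below are that reformatting plus refreshed notebook execution metadata.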
[notebook cell metadata hunks only: iopub execution timestamps updated from the 2024-11-05 run to the 2024-11-07 run; no source changes]
@@ -445,10 +445,7 @@
     "prompts = tokenizer.apply_chat_template(CONVS, tokenize=False)\n",
     "\n",
     "url = \"http://localhost:30030/classify\"\n",
-    "data = {\n",
-    "    \"model\": \"Skywork/Skywork-Reward-Llama-3.1-8B-v0.2\", \n",
-    "    \"text\": prompts\n",
-    "}\n",
+    "data = {\"model\": \"Skywork/Skywork-Reward-Llama-3.1-8B-v0.2\", \"text\": prompts}\n",
     "\n",
     "responses = requests.post(url, json=data).json()\n",
     "for response in responses:\n",
[notebook cell metadata hunk only: iopub execution timestamps updated from the 2024-11-05 run to the 2024-11-07 run; no source changes]
......
[notebook cell metadata hunks only: iopub execution timestamps updated from the 2024-11-05 run to the 2024-11-07 run; no source changes]
......
[notebook cell metadata hunks only: iopub execution timestamps updated from the 2024-11-05 run to the 2024-11-07 run; no source changes]
@@ -297,7 +297,10 @@
     "response = client.chat.completions.create(\n",
     "    model=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n",
     "    messages=[\n",
-    "        {\"role\": \"user\", \"content\": \"Give me the information of the capital of France in the JSON format.\"},\n",
+    "        {\n",
+    "            \"role\": \"user\",\n",
+    "            \"content\": \"Give me the information of the capital of France in the JSON format.\",\n",
+    "        },\n",
     "    ],\n",
     "    temperature=0,\n",
     "    max_tokens=128,\n",
[notebook cell metadata hunks only: iopub execution timestamps updated from the 2024-11-05 run to the 2024-11-07 run; no source changes]
......
[notebook cell metadata hunks only: iopub execution timestamps updated from the 2024-11-05 run to the 2024-11-07 run; no source changes]
......
[notebook cell metadata hunks only: iopub execution timestamps updated from the 2024-11-05 run to the 2024-11-07 run; no source changes]
......
@@ -31,7 +31,7 @@ extensions = [
 ]
 
 nbsphinx_allow_errors = True
-nbsphinx_execute = 'never'
+nbsphinx_execute = "never"
 autosectionlabel_prefix_document = True
 nbsphinx_allow_directives = True
@@ -49,7 +49,7 @@ myst_enable_extensions = [
 myst_heading_anchors = 3
 
-nbsphinx_kernel_name = 'python3'
+nbsphinx_kernel_name = "python3"
 nbsphinx_execute_arguments = [
     "--InlineBackend.figure_formats={'svg', 'pdf'}",
     "--InlineBackend.rc={'figure.dpi': 96}",
@@ -130,8 +130,10 @@ html_context = {
 html_static_path = ["_static"]
 html_css_files = ["css/custom_log.css"]
 
+
 def setup(app):
-    app.add_css_file('css/custom_log.css')
+    app.add_css_file("css/custom_log.css")
 
+
 myst_enable_extensions = [
     "dollarmath",
......
[notebook cell metadata hunk only: iopub execution timestamps updated from the 2024-11-05 run to the 2024-11-07 run; no source changes]
@@ -49,7 +49,7 @@
     ")\n",
     "\n",
     "server_process = execute_shell_command(\n",
-    "\"\"\"\n",
+    "    \"\"\"\n",
     "python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n",
     "--port 30000 --host 0.0.0.0\n",
     "\"\"\"\n",
[notebook cell metadata hunks only: iopub execution timestamps updated from the 2024-11-05 run to the 2024-11-07 run; no source changes]
@@ -115,9 +115,7 @@
     "\n",
     "data = {\n",
     "    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n",
-    "    \"messages\": [\n",
-    "        {\"role\": \"user\", \"content\": \"What is the capital of France?\"}\n",
-    "    ]\n",
+    "    \"messages\": [{\"role\": \"user\", \"content\": \"What is the capital of France?\"}],\n",
     "}\n",
     "\n",
     "response = requests.post(url, json=data)\n",
[notebook cell metadata hunks only: iopub execution timestamps updated from the 2024-11-05 run to the 2024-11-07 run; no source changes]
@@ -197,7 +195,7 @@
     "# Handle the streaming output\n",
     "for chunk in response:\n",
     "    if chunk.choices[0].delta.content:\n",
-    "        print(chunk.choices[0].delta.content, end='', flush=True)"
+    "        print(chunk.choices[0].delta.content, end=\"\", flush=True)"
    ]
   },
   {
[notebook cell metadata hunks only: iopub execution timestamps updated from the 2024-11-05 run to the 2024-11-07 run; no source changes]
......
@@ -71,7 +71,7 @@
    "source": [
     "import json\n",
     "import os\n",
     "from typing import List\n",
     "\n",
     "import chromadb\n",
     "\n",
@@ -80,7 +80,7 @@
     "if not os.path.exists(path_qca):\n",
     "    !wget https://virattt.github.io/datasets/abnb-2023-10k.json -O airbnb-2023-10k-qca.json\n",
     "\n",
-    "with open(path_qca, 'r') as f:\n",
+    "with open(path_qca, \"r\") as f:\n",
     "    question_context_answers = json.load(f)\n",
     "\n",
     "chroma_client = chromadb.PersistentClient()\n",
@@ -88,7 +88,7 @@
     "if collection.count() == 0:\n",
     "    collection.add(\n",
     "        documents=[qca[\"context\"] for qca in question_context_answers],\n",
-    "        ids=[str(i) for i in range(len(question_context_answers))]\n",
+    "        ids=[str(i) for i in range(len(question_context_answers))],\n",
     "    )"
    ],
    "metadata": {
@@ -123,7 +123,7 @@
     "\n",
     "load_dotenv()\n",
     "\n",
-    "os.environ['TOKENIZERS_PARALLELISM'] = \"false\"\n",
+    "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
     "\n",
     "p = Parea(api_key=os.getenv(\"PAREA_API_KEY\"), project_name=\"rag_sglang\")\n",
     "p.integrate_with_sglang()\n",
@@ -150,10 +150,7 @@
    "source": [
     "@trace\n",
     "def retrieval(question: str) -> List[str]:\n",
-    "    return collection.query(\n",
-    "        query_texts=[question],\n",
-    "        n_results=1\n",
-    "    )['documents'][0]"
+    "    return collection.query(query_texts=[question], n_results=1)[\"documents\"][0]"
    ],
    "metadata": {
     "collapsed": false
@@ -176,7 +173,9 @@
     "@function\n",
     "def generation_sglang(s, question: str, *context: str):\n",
     "    context = \"\\n\".join(context)\n",
-    "    s += user(f'Given this question:\\n{question}\\n\\nAnd this context:\\n{context}\\n\\nAnswer the question.')\n",
+    "    s += user(\n",
+    "        f\"Given this question:\\n{question}\\n\\nAnd this context:\\n{context}\\n\\nAnswer the question.\"\n",
+    "    )\n",
     "    s += assistant(gen(\"answer\"))\n",
     "\n",
     "\n",
@@ -223,7 +222,9 @@
     "    return generation(question, *contexts)\n",
     "\n",
     "\n",
-    "rag_pipeline(\"When did the World Health Organization formally declare an end to the COVID-19 global health emergency?\")"
+    "rag_pipeline(\n",
+    "    \"When did the World Health Organization formally declare an end to the COVID-19 global health emergency?\"\n",
+    ")"
    ]
   },
   {
@@ -271,7 +272,10 @@
    "execution_count": null,
    "outputs": [],
    "source": [
-    "from parea.evals.rag import context_query_relevancy_factory, percent_target_supported_by_context_factory\n",
+    "from parea.evals.rag import (\n",
+    "    context_query_relevancy_factory,\n",
+    "    percent_target_supported_by_context_factory,\n",
+    ")\n",
     "\n",
     "\n",
     "context_relevancy_eval = context_query_relevancy_factory()\n",
@@ -280,10 +284,7 @@
     "\n",
     "@trace(eval_funcs=[context_relevancy_eval, percent_target_supported_by_context])\n",
     "def retrieval(question: str) -> List[str]:\n",
-    "    return collection.query(\n",
-    "        query_texts=[question],\n",
-    "        n_results=1\n",
-    "    )['documents'][0]"
+    "    return collection.query(query_texts=[question], n_results=1)[\"documents\"][0]"
    ],
    "metadata": {
     "collapsed": false
@@ -310,10 +311,13 @@
     "answer_context_faithfulness = answer_context_faithfulness_statement_level_factory()\n",
     "answer_matches_target_llm_grader = answer_matches_target_llm_grader_factory()\n",
     "\n",
+    "\n",
     "@function\n",
     "def generation_sglang(s, question: str, *context: str):\n",
     "    context = \"\\n\".join(context)\n",
-    "    s += user(f'Given this question:\\n{question}\\n\\nAnd this context:\\n{context}\\n\\nAnswer the question.')\n",
+    "    s += user(\n",
+    "        f\"Given this question:\\n{question}\\n\\nAnd this context:\\n{context}\\n\\nAnswer the question.\"\n",
+    "    )\n",
     "    s += assistant(gen(\"answer\", max_tokens=1_000))\n",
     "\n",
     "\n",
@@ -357,7 +361,9 @@
     "    return generation(question, *contexts)\n",
     "\n",
     "\n",
-    "rag_pipeline(\"When did the World Health Organization formally declare an end to the COVID-19 global health emergency?\")"
+    "rag_pipeline(\n",
+    "    \"When did the World Health Organization formally declare an end to the COVID-19 global health emergency?\"\n",
+    ")"
    ],
    "metadata": {
     "collapsed": false
@@ -402,6 +408,7 @@
    "source": [
     "!pip install nest-asyncio\n",
     "import nest_asyncio\n",
+    "\n",
     "nest_asyncio.apply()"
    ],
    "metadata": {
@@ -461,7 +468,7 @@
    ],
    "source": [
     "e = p.experiment(\n",
-    "    'RAG',\n",
+    "    \"RAG\",\n",
     "    data=[\n",
     "        {\n",
     "            \"question\": qca[\"question\"],\n",
@@ -469,7 +476,7 @@
     "        }\n",
     "        for qca in question_context_answers\n",
     "    ],\n",
-    "    func=rag_pipeline\n",
+    "    func=rag_pipeline,\n",
     ").run()"
    ],
    "metadata": {
......
@@ -7,6 +7,7 @@ from sglang.srt.hf_transformers_utils import get_tokenizer
 MODEL_PATH = "meta-llama/Llama-3.1-8B-Instruct"
 
+
 def main():
     # Sample prompts.
     prompts = [
......
@@ -39,7 +39,7 @@ class ModelConfig:
         revision: Optional[str] = None,
         context_length: Optional[int] = None,
         model_override_args: Optional[dict] = None,
-        is_embedding: Optional[bool] = None
+        is_embedding: Optional[bool] = None,
     ) -> None:
         # Parse args
         self.model_override_args = json.loads(model_override_args)
@@ -52,7 +52,9 @@
         self.hf_text_config = get_hf_text_config(self.hf_config)
 
         # Check model type
-        self.is_generation = is_generation_model(self.hf_config.architectures, is_embedding)
+        self.is_generation = is_generation_model(
+            self.hf_config.architectures, is_embedding
+        )
         self.is_multimodal = is_multimodal_model(self.hf_config.architectures)
         self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
......
@@ -122,16 +122,14 @@ class QuantizationConfig(ABC):
         """
         raise NotImplementedError
 
-def method_has_implemented_embedding(
-        method_class: Type[QuantizeMethodBase]) -> bool:
+
+def method_has_implemented_embedding(method_class: Type[QuantizeMethodBase]) -> bool:
     """
     Not all quant methods have embedding implemented, so we need to check that
     it exists for our given method. We check this by making sure the function
     has been changed from the base implementation.
     """
-    base_embedding = inspect.getattr_static(QuantizeMethodBase, "embedding",
-                                            None)
+    base_embedding = inspect.getattr_static(QuantizeMethodBase, "embedding", None)
     class_embedding = inspect.getattr_static(method_class, "embedding", None)
-    return (class_embedding is not None
-            and class_embedding is not base_embedding)
+    return class_embedding is not None and class_embedding is not base_embedding
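The getattr_static comparison above is the usual way to detect whether a subclass really overrides a base-class method; a toy illustration with made-up class names (not from this diff):

import inspect


class Base:
    def embedding(self):
        raise NotImplementedError


class WithEmbedding(Base):
    def embedding(self):
        return "ok"


class WithoutEmbedding(Base):
    pass


def has_overridden_embedding(cls) -> bool:
    # Same check as method_has_implemented_embedding: the attribute exists and
    # is not the identical function object inherited from Base.
    base = inspect.getattr_static(Base, "embedding", None)
    sub = inspect.getattr_static(cls, "embedding", None)
    return sub is not None and sub is not base


print(has_overridden_embedding(WithEmbedding))  # True
print(has_overridden_embedding(WithoutEmbedding))  # False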
@@ -27,59 +27,67 @@ DEFAULT_VOCAB_PADDING_SIZE = 64
(formatting-only hunks; each is shown below as its black-reformatted result)

class UnquantizedEmbeddingMethod(QuantizeMethodBase):
    """Unquantized method for embeddings."""

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Create weights for embedding layer."""
        weight = Parameter(
            torch.empty(
                sum(output_partition_sizes),
                input_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return F.linear(x, layer.weight, bias)

    def embedding(self, layer: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor:
        return F.embedding(input_, layer.weight)


def pad_vocab_size(vocab_size: int, pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
    """Pad the vocab size to the given value."""
    return ((vocab_size + pad_to - 1) // pad_to) * pad_to


def vocab_range_from_per_partition_vocab_size(
    per_partition_vocab_size: int, rank: int, offset: int = 0
) -> Sequence[int]:
    index_f = rank * per_partition_vocab_size
    index_l = index_f + per_partition_vocab_size
    return index_f + offset, index_l + offset


def vocab_range_from_global_vocab_size(
    global_vocab_size: int, rank: int, world_size: int, offset: int = 0
) -> Sequence[int]:
    per_partition_vocab_size = divide(global_vocab_size, world_size)
    return vocab_range_from_per_partition_vocab_size(
        per_partition_vocab_size, rank, offset=offset
    )


@dataclass
class VocabParallelEmbeddingShardIndices:
    """Indices for a shard of a vocab parallel embedding."""

    padded_org_vocab_start_index: int
    padded_org_vocab_end_index: int
    padded_added_vocab_start_index: int
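A quick arithmetic check of the rounding in pad_vocab_size above (standalone sketch; 64 is the DEFAULT_VOCAB_PADDING_SIZE shown in the hunk header):

def pad_vocab_size(vocab_size: int, pad_to: int = 64) -> int:
    # Round vocab_size up to the next multiple of pad_to.
    return ((vocab_size + pad_to - 1) // pad_to) * pad_to


print(pad_vocab_size(32000))   # 32000 (already a multiple of 64)
print(pad_vocab_size(32001))   # 32064
print(pad_vocab_size(128256))  # 128256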
...@@ -100,13 +108,11 @@ class VocabParallelEmbeddingShardIndices: ...@@ -100,13 +108,11 @@ class VocabParallelEmbeddingShardIndices:
@property @property
def num_org_elements_padded(self) -> int: def num_org_elements_padded(self) -> int:
return (self.padded_org_vocab_end_index - return self.padded_org_vocab_end_index - self.padded_org_vocab_start_index
self.padded_org_vocab_start_index)
@property @property
def num_added_elements_padded(self) -> int: def num_added_elements_padded(self) -> int:
return (self.padded_added_vocab_end_index - return self.padded_added_vocab_end_index - self.padded_added_vocab_start_index
self.padded_added_vocab_start_index)
@property @property
def num_org_vocab_padding(self) -> int: def num_org_vocab_padding(self) -> int:
...@@ -122,17 +128,14 @@ class VocabParallelEmbeddingShardIndices: ...@@ -122,17 +128,14 @@ class VocabParallelEmbeddingShardIndices:
def __post_init__(self): def __post_init__(self):
# sanity checks # sanity checks
assert (self.padded_org_vocab_start_index <= assert self.padded_org_vocab_start_index <= self.padded_org_vocab_end_index
self.padded_org_vocab_end_index) assert self.padded_added_vocab_start_index <= self.padded_added_vocab_end_index
assert (self.padded_added_vocab_start_index <=
self.padded_added_vocab_end_index)
assert self.org_vocab_start_index <= self.org_vocab_end_index assert self.org_vocab_start_index <= self.org_vocab_end_index
assert self.added_vocab_start_index <= self.added_vocab_end_index assert self.added_vocab_start_index <= self.added_vocab_end_index
assert self.org_vocab_start_index <= self.padded_org_vocab_start_index assert self.org_vocab_start_index <= self.padded_org_vocab_start_index
assert (self.added_vocab_start_index <= assert self.added_vocab_start_index <= self.padded_added_vocab_start_index
self.padded_added_vocab_start_index)
assert self.org_vocab_end_index <= self.padded_org_vocab_end_index assert self.org_vocab_end_index <= self.padded_org_vocab_end_index
assert self.added_vocab_end_index <= self.padded_added_vocab_end_index assert self.added_vocab_end_index <= self.padded_added_vocab_end_index
...@@ -142,20 +145,27 @@ class VocabParallelEmbeddingShardIndices: ...@@ -142,20 +145,27 @@ class VocabParallelEmbeddingShardIndices:
@torch.jit.script @torch.jit.script
def get_masked_input_and_mask( def get_masked_input_and_mask(
input_: torch.Tensor, org_vocab_start_index: int, input_: torch.Tensor,
org_vocab_end_index: int, num_org_vocab_padding: int, org_vocab_start_index: int,
added_vocab_start_index: int, org_vocab_end_index: int,
added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]: num_org_vocab_padding: int,
added_vocab_start_index: int,
added_vocab_end_index: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
# torch.jit.script will fuse all of the pointwise ops below # torch.jit.script will fuse all of the pointwise ops below
# into a single kernel, making it very fast # into a single kernel, making it very fast
org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ < org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ < org_vocab_end_index)
org_vocab_end_index)
added_vocab_mask = (input_ >= added_vocab_start_index) & ( added_vocab_mask = (input_ >= added_vocab_start_index) & (
input_ < added_vocab_end_index) input_ < added_vocab_end_index
added_offset = added_vocab_start_index - ( )
org_vocab_end_index - org_vocab_start_index) - num_org_vocab_padding added_offset = (
valid_offset = (org_vocab_start_index * added_vocab_start_index
org_vocab_mask) + (added_offset * added_vocab_mask) - (org_vocab_end_index - org_vocab_start_index)
- num_org_vocab_padding
)
valid_offset = (org_vocab_start_index * org_vocab_mask) + (
added_offset * added_vocab_mask
)
vocab_mask = org_vocab_mask | added_vocab_mask vocab_mask = org_vocab_mask | added_vocab_mask
input_ = vocab_mask * (input_ - valid_offset) input_ = vocab_mask * (input_ - valid_offset)
return input_, ~vocab_mask return input_, ~vocab_mask
...@@ -200,15 +210,17 @@ class VocabParallelEmbedding(torch.nn.Module): ...@@ -200,15 +210,17 @@ class VocabParallelEmbedding(torch.nn.Module):
prefix: full name of the layer in the state dict prefix: full name of the layer in the state dict
""" # noqa: E501 """ # noqa: E501
def __init__(self, def __init__(
num_embeddings: int, self,
embedding_dim: int, num_embeddings: int,
params_dtype: Optional[torch.dtype] = None, embedding_dim: int,
org_num_embeddings: Optional[int] = None, params_dtype: Optional[torch.dtype] = None,
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, org_num_embeddings: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None, padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
prefix: str = "", quant_config: Optional[QuantizationConfig] = None,
enable_tp: bool = True): prefix: str = "",
enable_tp: bool = True,
):
super().__init__() super().__init__()
self.enable_tp = enable_tp self.enable_tp = enable_tp
@@ -223,18 +235,22 @@ class VocabParallelEmbedding(torch.nn.Module):
        self.padding_size = padding_size
        self.org_vocab_size = org_num_embeddings or num_embeddings
        num_added_embeddings = num_embeddings - self.org_vocab_size
        self.org_vocab_size_padded = pad_vocab_size(
            self.org_vocab_size, self.padding_size
        )
        self.num_embeddings_padded = pad_vocab_size(
            self.org_vocab_size_padded + num_added_embeddings, self.padding_size
        )
        assert self.org_vocab_size_padded <= self.num_embeddings_padded
        self.shard_indices = self._get_indices(
            self.num_embeddings_padded,
            self.org_vocab_size_padded,
            self.num_embeddings,
            self.org_vocab_size,
            tp_rank,
            self.tp_size,
        )
        self.embedding_dim = embedding_dim

        linear_method = None
@@ -248,11 +264,13 @@ class VocabParallelEmbedding(torch.nn.Module):
        # layer type like ParallelLMHead, this is not important.
        is_embedding_layer = type(self.__class__) is VocabParallelEmbedding
        linear_method_implements_embedding = method_has_implemented_embedding(
            type(linear_method)
        )
        if is_embedding_layer and not linear_method_implements_embedding:
            raise NotImplementedError(
                f"The class {type(linear_method).__name__} must implement "
                "the 'embedding' method, see UnquantizedEmbeddingMethod."
            )

        self.linear_method: QuantizeMethodBase = linear_method
@@ -260,53 +278,68 @@ class VocabParallelEmbedding(torch.nn.Module):
            params_dtype = torch.get_default_dtype()
        # Divide the weight matrix along the vocabulary dimension.
        self.num_added_embeddings = self.num_embeddings - self.org_vocab_size
        self.num_embeddings_per_partition = divide(
            self.num_embeddings_padded, self.tp_size
        )
        assert (
            self.shard_indices.num_elements_padded == self.num_embeddings_per_partition
        )
        self.num_org_embeddings_per_partition = (
            self.shard_indices.org_vocab_end_index
            - self.shard_indices.org_vocab_start_index
        )
        self.num_added_embeddings_per_partition = (
            self.shard_indices.added_vocab_end_index
            - self.shard_indices.added_vocab_start_index
        )

        self.linear_method.create_weights(
            self,
            self.embedding_dim,
            [self.num_embeddings_per_partition],
            self.embedding_dim,
            self.num_embeddings_padded,
            params_dtype=params_dtype,
            weight_loader=self.weight_loader,
        )

    @classmethod
    def _get_indices(
        cls,
        vocab_size_padded: int,
        org_vocab_size_padded: int,
        vocab_size: int,
        org_vocab_size: int,
        tp_rank: int,
        tp_size: int,
    ) -> VocabParallelEmbeddingShardIndices:
        """Get start and end indices for vocab parallel embedding, following the
        layout outlined in the class docstring, based on the given tp_rank and
        tp_size."""
        num_added_embeddings_padded = vocab_size_padded - org_vocab_size_padded
        padded_org_vocab_start_index, padded_org_vocab_end_index = (
            vocab_range_from_global_vocab_size(org_vocab_size_padded, tp_rank, tp_size)
        )
        padded_added_vocab_start_index, padded_added_vocab_end_index = (
            vocab_range_from_global_vocab_size(
                num_added_embeddings_padded, tp_rank, tp_size, offset=org_vocab_size
            )
        )
        # remove padding
        org_vocab_start_index = min(padded_org_vocab_start_index, org_vocab_size)
        org_vocab_end_index = min(padded_org_vocab_end_index, org_vocab_size)
        added_vocab_start_index = min(padded_added_vocab_start_index, vocab_size)
        added_vocab_end_index = min(padded_added_vocab_end_index, vocab_size)
        return VocabParallelEmbeddingShardIndices(
            padded_org_vocab_start_index,
            padded_org_vocab_end_index,
            padded_added_vocab_start_index,
            padded_added_vocab_end_index,
            org_vocab_start_index,
            org_vocab_end_index,
            added_vocab_start_index,
            added_vocab_end_index,
        )
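For intuition, a rough stand-in for how such per-rank ranges are typically derived (the real vocab_range_from_global_vocab_size lives elsewhere; this even-split sketch is an assumption, shown only to illustrate the arithmetic that _get_indices then clamps with min(...)):

def even_vocab_range(global_size: int, rank: int, world_size: int, offset: int = 0):
    # Each rank owns one contiguous, equally sized slice of the padded vocab.
    per_rank = global_size // world_size
    start = offset + rank * per_rank
    return start, start + per_rank

# Padded original vocab of 128 on tp_size=2:
#   rank 0 -> (0, 64), rank 1 -> (64, 128)
print(even_vocab_range(128, 0, 2), even_vocab_range(128, 1, 2))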

    def get_sharded_to_full_mapping(self) -> Optional[List[int]]:
        """Get a mapping that can be used to reindex the gathered
@@ -326,32 +359,49 @@ class VocabParallelEmbedding(torch.nn.Module):
        added_embeddings: List[int] = []
        padding: List[int] = []
        for tp_rank in range(self.tp_size):
            shard_indices = self._get_indices(
                self.num_embeddings_padded,
                self.org_vocab_size_padded,
                self.num_embeddings,
                self.org_vocab_size,
                tp_rank,
                self.tp_size,
            )
            range_start = self.num_embeddings_per_partition * tp_rank
            range_end = self.num_embeddings_per_partition * (tp_rank + 1)
            base_embeddings.extend(
                range(range_start, range_start + shard_indices.num_org_elements)
            )
            padding.extend(
                range(
                    range_start + shard_indices.num_org_elements,
                    range_start + shard_indices.num_org_elements_padded,
                )
            )
            added_embeddings.extend(
                range(
                    range_start + shard_indices.num_org_elements_padded,
                    range_start
                    + shard_indices.num_org_elements_padded
                    + shard_indices.num_added_elements,
                )
            )
            padding.extend(
                range(
                    range_start
                    + shard_indices.num_org_elements_padded
                    + shard_indices.num_added_elements,
                    range_start
                    + shard_indices.num_org_elements_padded
                    + shard_indices.num_added_elements_padded,
                )
            )
            assert (
                range_start
                + shard_indices.num_org_elements_padded
                + shard_indices.num_added_elements_padded
                == range_end
            )
        ret = base_embeddings + added_embeddings + padding
        assert len(ret) == self.num_embeddings_padded
        return ret
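A hedged sketch of how a mapping like this is usually consumed after gathering the per-rank partitions (tensor names here are illustrative; the actual call sites are outside this diff): indexing the gathered rows with the mapping restores the original-vocab / added-vocab / padding order.

import torch

def reindex_gathered(gathered: torch.Tensor, mapping: list) -> torch.Tensor:
    # gathered: all ranks' padded partitions concatenated along dim 0
    # mapping:  output of get_sharded_to_full_mapping()
    idx = torch.tensor(mapping, dtype=torch.long, device=gathered.device)
    return gathered.index_select(0, idx)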
@@ -385,10 +435,14 @@ class VocabParallelEmbedding(torch.nn.Module):
        # If param packed on the same dim we are sharding on, then
        # need to adjust offsets of loaded weight by pack_factor.
        if packed_dim is not None and packed_dim == output_dim:
            packed_factor = (
                param.packed_factor
                if isinstance(param, BasevLLMParameter)
                else param.pack_factor
            )
            assert loaded_weight.shape[output_dim] == (
                self.org_vocab_size // param.packed_factor
            )
            start_idx = start_idx // packed_factor
            shard_size = shard_size // packed_factor
        else:
@@ -396,23 +450,24 @@ class VocabParallelEmbedding(torch.nn.Module):
        # Copy the data.
        loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
        param[: loaded_weight.shape[0]].data.copy_(loaded_weight)
        param[loaded_weight.shape[0] :].data.fill_(0)

    def forward(self, input_):
        if self.tp_size > 1:
            # Build the mask.
            masked_input, input_mask = get_masked_input_and_mask(
                input_,
                self.shard_indices.org_vocab_start_index,
                self.shard_indices.org_vocab_end_index,
                self.shard_indices.num_org_vocab_padding,
                self.shard_indices.added_vocab_start_index,
                self.shard_indices.added_vocab_end_index,
            )
        else:
            masked_input = input_
        # Get the embeddings.
        output_parallel = self.linear_method.embedding(self, masked_input.long())
        # Mask the output embedding.
        if self.tp_size > 1:
            output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
@@ -426,9 +481,9 @@ class VocabParallelEmbedding(torch.nn.Module):
        s = f"num_embeddings={self.num_embeddings_per_partition}"
        s += f", embedding_dim={self.embedding_dim}"
        s += f", org_vocab_size={self.org_vocab_size}"
        s += f", num_embeddings_padded={self.num_embeddings_padded}"
        if self.enable_tp:
            s += f", tp_size={self.tp_size}"
        return s
@@ -448,27 +503,38 @@ class ParallelLMHead(VocabParallelEmbedding):
        padding_size: padding size for the vocabulary.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        bias: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        org_num_embeddings: Optional[int] = None,
        padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__(
            num_embeddings,
            embedding_dim,
            params_dtype,
            org_num_embeddings,
            padding_size,
            quant_config,
            prefix,
        )
        self.quant_config = quant_config
        if bias:
            self.bias = Parameter(
                torch.empty(self.num_embeddings_per_partition, dtype=params_dtype)
            )
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)
...
@@ -86,8 +86,10 @@ class GenerateReqInput:
            self.parallel_sample_num = self.sampling_params.get("n", 1)
        else:  # isinstance(self.sampling_params, list):
            self.parallel_sample_num = self.sampling_params[0].get("n", 1)
            assert all(
                self.parallel_sample_num == sampling_params.get("n", 1)
                for sampling_params in self.sampling_params
            ), "The parallel_sample_num should be the same for all samples in sample params."

        if self.parallel_sample_num > 1 and self.is_single:
            self.is_single = False
...
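To make the invariant above concrete, a small illustrative batched request (the field values are made up): every entry of a sampling_params list must request the same n, otherwise the assertion fires.

sampling_params = [
    {"n": 2, "temperature": 0.8},
    {"n": 2, "temperature": 0.2},
]
parallel_sample_num = sampling_params[0].get("n", 1)
assert all(parallel_sample_num == sp.get("n", 1) for sp in sampling_params)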
@@ -911,8 +911,7 @@ class ScheduleBatch:
        keep_indices = [
            i
            for i in range(len(self.reqs))
            if not self.reqs[i].finished() and self.reqs[i] is not being_chunked_req
        ]
        if keep_indices is None or len(keep_indices) == 0:
@@ -1043,6 +1042,7 @@ class ScheduleBatch:
        for req in self.reqs:
            req.started_time = time.time()


@dataclasses.dataclass
class ModelWorkerBatch:
    # The batch id
...
@@ -224,8 +224,8 @@ class Scheduler:
        self.forward_ct = 0
        self.forward_ct_decode = 0
        self.num_generated_tokens = 0
        self.last_stats_tic = time.time()  # time of last stats for every iter
        self.last_log_tic = time.time()  # time of last log for print decode log
        self.stream_interval = server_args.stream_interval

        # Init chunked prefill
@@ -566,9 +566,7 @@ class Scheduler:
            and not self.last_batch.is_empty()
        ):
            if self.being_chunked_req:
                self.last_batch.filter_batch(being_chunked_req=self.being_chunked_req)
                self.tree_cache.cache_unfinished_req(self.being_chunked_req)
                # Inflight request keeps its rid but will get a new req_pool_idx.
                self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx)
@@ -628,9 +626,7 @@ class Scheduler:
        has_inflight = self.being_chunked_req is not None
        if has_inflight:
            self.being_chunked_req.init_next_round_input()
            self.being_chunked_req = adder.add_inflight_req(self.being_chunked_req)

        if self.lora_paths:
            lora_set = (
@@ -813,7 +809,8 @@ class Scheduler:
            embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
            ret = embeddings, model_worker_batch.bid
        return ret

    def get_stats(self, batch: ScheduleBatch):
        # TODO: get stats for chunked prefill
        now = time.time()
@@ -829,8 +826,8 @@ class Scheduler:
        # set stats from prefill
        if self.stats is not None:
            # new_seq=self.stats.new_seq
            cache_hit_rate = self.stats.cache_hit_rate
            token_usage = self.stats.token_usage
        # Iteration stats
        num_prompt_tokens_iter = 0
        num_generation_tokens_iter = 0
@@ -851,15 +848,19 @@ class Scheduler:
        # _, next_token_ids, _ = result
        if batch is not None:
            num_generation_tokens_iter = len(batch.output_ids)
            gen_throughput = round(
                num_generation_tokens_iter / (now - self.last_stats_tic), 2
            )

            for i, req in enumerate(batch.reqs):
                # NOTE: Batch forward mode is extend before decoding starts,
                if batch.forward_mode.is_extend():
                    num_prompt_tokens_iter = len(batch.input_ids) + sum(
                        batch.prefix_lens
                    )
                    time_to_first_tokens_iter.append(now - req.started_time)
                else:
                    time_per_output_tokens_iter.append(now - self.last_stats_tic)

                if req.finished():
                    time_e2e_requests.append(now - req.created_time)
@@ -867,9 +868,10 @@ class Scheduler:
                    num_prompt_tokens_requests.append(len(req.origin_input_ids))
                    num_generation_tokens_requests.append(len(req.output_ids))
                    finished_reason_requests.append(
                        req.finished_reason.to_json()
                        if req.finished_reason is not None
                        else None
                    )

        return Stats(
            new_seq=new_seq,
@@ -893,7 +895,7 @@ class Scheduler:
            max_running_requests=self.max_running_requests,
        )

    def log_stats(self, stats: Stats):
        self.metrics_collector.log_stats(stats)

    def process_batch_result(self, batch: ScheduleBatch, result):
@@ -1003,9 +1005,7 @@ class Scheduler:
            if req.is_retracted:
                continue

            if self.server_args.enable_overlap_schedule and (req.finished()):
                self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
                continue
@@ -1031,7 +1031,10 @@ class Scheduler:
            self.token_to_kv_pool.free_group_end()

        self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
        if (
            self.tp_rank == 0
            and self.forward_ct_decode % self.server_args.decode_log_interval == 0
        ):
            self.print_decode_stats()

    def add_logprob_return_values(
...
@@ -215,7 +215,7 @@ class TokenizerManager:
                logprob_start_len,
                top_logprobs_num,
                obj.stream,
                obj.lora_path,
            )
        elif isinstance(obj, EmbeddingReqInput):
            tokenized_obj = TokenizedEmbeddingReqInput(
@@ -290,7 +290,9 @@ class TokenizerManager:
        # Tokenize all requests
        objs = [obj[i] for i in range(batch_size)]
        tokenized_objs = await asyncio.gather(
            *(self._tokenize_one_request(obj) for obj in objs)
        )

        # Cache the common prefix for parallel sampling
        for i in range(batch_size):
@@ -322,7 +324,9 @@ class TokenizerManager:
        rid_to_index = {rid: i for i, rid in enumerate(rids)}
        task_map = {asyncio.create_task(gen.__anext__()): gen for gen in generators}
        while task_map:
            done, _ = await asyncio.wait(
                task_map.keys(), return_when=asyncio.FIRST_COMPLETED
            )
            for task in done:
                gen = task_map.pop(task)
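The loop above is the standard asyncio consume-as-completed idiom; a self-contained sketch of the same pattern (generic asyncio, independent of the TokenizerManager internals) for readers unfamiliar with it:

import asyncio


async def stream(name: str, n: int):
    # Toy async generator standing in for a per-request output stream.
    for i in range(n):
        await asyncio.sleep(0.01)
        yield f"{name}:{i}"


async def merge(generators):
    # Map each generator's pending "next item" task back to its generator.
    task_map = {asyncio.create_task(g.__anext__()): g for g in generators}
    while task_map:
        done, _ = await asyncio.wait(task_map.keys(), return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            gen = task_map.pop(task)
            try:
                print(task.result())
            except StopAsyncIteration:
                continue
            # Re-arm the generator so its next item competes again.
            task_map[asyncio.create_task(gen.__anext__())] = gen


asyncio.run(merge([stream("a", 2), stream("b", 3)]))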
@@ -367,7 +371,7 @@ class TokenizerManager:
        if self.server_args.dp_size == 1:
            res = await self.mem_pool_size
            return res.size
        else:  # self.server_args.dp_size > 1
            self.mem_pool_size_tmp = []
            res = await self.mem_pool_size
            ret = [r.size for r in res]
@@ -399,7 +403,7 @@ class TokenizerManager:
            self.server_args.load_format = obj.load_format
            self.model_path = obj.model_path
            return result.success, result.message
        else:  # self.server_args.dp_size > 1
            self.model_update_tmp = []
            result = await self.model_update_result
@@ -470,7 +474,7 @@ class TokenizerManager:
        if isinstance(recv_obj, UpdateWeightReqOutput):
            if self.server_args.dp_size == 1:
                self.model_update_result.set_result(recv_obj)
            else:  # self.server_args.dp_size > 1
                self.model_update_tmp.append(recv_obj)
                # set future if all the results are received
                if len(self.model_update_tmp) == self.server_args.dp_size:
@@ -479,7 +483,7 @@ class TokenizerManager:
        elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
            if self.server_args.dp_size == 1:
                self.mem_pool_size.set_result(recv_obj)
            else:  # self.server_args.dp_size > 1
                self.mem_pool_size_tmp.append(recv_obj)
                # set future if all the results are received
                if len(self.mem_pool_size_tmp) == self.server_args.dp_size:
...
@@ -130,27 +130,65 @@ class Metrics:
        self.counter_prompt_tokens = Counter(
            name="sglang:prompt_tokens_total",
            documentation="Number of prefill tokens processed.",
            labelnames=labelnames,
        )
        self.counter_generation_tokens = Counter(
            name="sglang:generation_tokens_total",
            documentation="Number of generation tokens processed.",
            labelnames=labelnames,
        )
        self.histogram_time_to_first_token = Histogram(
            name="sglang:time_to_first_token_seconds",
            documentation="Histogram of time to first token in seconds.",
            labelnames=labelnames,
            buckets=[
                0.001,
                0.005,
                0.01,
                0.02,
                0.04,
                0.06,
                0.08,
                0.1,
                0.25,
                0.5,
                0.75,
                1.0,
                2.5,
                5.0,
                7.5,
                10.0,
                15.0,
                20.0,
                25.0,
                30.0,
            ],
        )
        self.histogram_time_per_output_token = Histogram(
            name="sglang:time_per_output_token_seconds",
            documentation="Histogram of time per output token in seconds.",
            labelnames=labelnames,
            buckets=[
                0.005,
                0.01,
                0.015,
                0.02,
                0.025,
                0.03,
                0.04,
                0.05,
                0.075,
                0.1,
                0.15,
                0.2,
                0.3,
                0.4,
                0.5,
                0.75,
                1.0,
                2.5,
            ],
        )
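For orientation, a minimal prometheus_client sketch (plain library usage, separate from the SGLang wrappers; the metric name and label are illustrative) showing how a labeled histogram like the ones above records an observation:

from prometheus_client import Histogram

ttft = Histogram(
    name="demo_time_to_first_token_seconds",
    documentation="Histogram of time to first token in seconds.",
    labelnames=["model_name"],
    buckets=[0.01, 0.1, 0.5, 1.0, 5.0],
)
# Each observation lands in the matching bucket for its label set.
ttft.labels(model_name="my-model").observe(0.23)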
        # Request Stats
        # Metadata
@@ -245,14 +283,19 @@ class PrometheusMetricsCollector(MetricsCollector):
            stats.num_generation_tokens_requests,
        )

        self._log_counter(
            self.metrics.counter_prompt_tokens, stats.num_prompt_tokens_iter
        )
        self._log_counter(
            self.metrics.counter_generation_tokens, stats.num_generation_tokens_iter
        )
        self._log_histogram(
            self.metrics.histogram_time_to_first_token, stats.time_to_first_tokens_iter
        )
        self._log_histogram(
            self.metrics.histogram_time_per_output_token,
            stats.time_per_output_tokens_iter,
        )
        # self._log_gauge(self.metrics.gpu_cache_usage_sys, stats.gpu_cache_usage_sys)
        self._log_gauge(self.metrics.num_running_sys, stats.num_running_req)
...
@@ -28,7 +28,7 @@ from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

# from sglang.srt.layers.activation import get_act_fn
from sglang.srt.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
@@ -47,15 +47,14 @@ class GPT2Attention(nn.Module):
        self,
        layer_id: int,
        config: GPT2Config,
        cache_config=None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        total_num_heads = config.num_attention_heads
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
        assert total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = total_num_heads // tensor_model_parallel_world_size
        self.head_dim = self.hidden_size // total_num_heads
@@ -76,11 +75,13 @@ class GPT2Attention(nn.Module):
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )
        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            scaling=self.scale,
            num_kv_heads=total_num_heads,
            layer_id=layer_id,
        )

    def forward(
        self,
@@ -119,10 +120,14 @@ class GPT2MLP(nn.Module):
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )
        self.act = get_act_fn(
            config.activation_function, quant_config, intermediate_size
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states, _ = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.c_proj(hidden_states)
@@ -135,27 +140,20 @@ class GPT2Block(nn.Module):
        self,
        layer_id: int,
        config: GPT2Config,
        cache_config=None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        hidden_size = config.hidden_size
        inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.attn = GPT2Attention(
            layer_id, config, cache_config, quant_config, prefix=f"{prefix}.attn"
        )
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = GPT2MLP(inner_dim, config, quant_config, prefix=f"{prefix}.mlp")

    def forward(
        self,
@@ -179,13 +177,12 @@ class GPT2Block(nn.Module):
        return hidden_states


class GPT2Model(nn.Module):
    def __init__(
        self,
        config: GPT2Config,
        cache_config=None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
@@ -229,16 +226,15 @@ class GPT2LMHeadModel(nn.Module):
    def __init__(
        self,
        config: GPT2Config,
        cache_config=None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.transformer = GPT2Model(
            config, cache_config, quant_config, prefix="transformer"
        )
        self.lm_head = self.transformer.wte
        self.logits_processor = LogitsProcessor(config)
@@ -254,8 +250,6 @@ class GPT2LMHeadModel(nn.Module):
            input_ids, hidden_states, self.lm_head.weight, forward_batch
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
@@ -280,8 +274,8 @@ class GPT2LMHeadModel(nn.Module):
            if not name.endswith(".weight"):
                continue
            loaded_weight = loaded_weight.t()
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)


EntryClass = GPT2LMHeadModel
@@ -419,6 +419,7 @@ def launch_engine(
    for i in range(len(scheduler_pipe_readers)):
        scheduler_pipe_readers[i].recv()


def add_prometheus_middleware(app: FastAPI):
    # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.1/vllm/entrypoints/openai/api_server.py#L216
    from prometheus_client import CollectorRegistry, make_asgi_app, multiprocess
@@ -490,6 +491,7 @@ def launch_server(
    finally:
        t.join()


def _set_prometheus_env():
    # Set prometheus multiprocess directory
    # sglang uses prometheus multiprocess mode
@@ -506,6 +508,7 @@ def _set_prometheus_env():
    os.environ["PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
    logger.debug(f"PROMETHEUS_MULTIPROC_DIR: {os.environ['PROMETHEUS_MULTIPROC_DIR']}")


def _set_envs_and_config(server_args: ServerArgs):
    # Set global environments
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
@@ -763,8 +766,8 @@ class Engine:
        # runtime server default log level is log
        # offline engine works in scripts, so we set it to error
        if "log_level" not in kwargs:
            kwargs["log_level"] = "error"

        server_args = ServerArgs(*args, **kwargs)
        launch_engine(server_args=server_args)
...
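As a usage note, the kwargs handling above means an offline Engine logs at error level unless the caller overrides it; a hedged sketch, assuming the Engine class is exported from the sglang package and that model_path is a valid ServerArgs field (both assumptions, not shown in this diff):

import sglang as sgl

# **kwargs are forwarded into ServerArgs, so any ServerArgs field can be passed.
# Without log_level, the offline default of "error" applies.
engine = sgl.Engine(model_path="meta-llama/Llama-3.1-8B-Instruct", log_level="info")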