{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "T4"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# ZhongJingGPT-2-1.8b\n",
        "\n",
        "A Traditional Chinese Medicine large language model, inspired by the wisdom of the eminent representative of ancient Chinese medical scholars, Zhang Zhongjing. This model aims to illuminate the profound knowledge of Traditional Chinese Medicine, bridging the gap between ancient wisdom and modern technology, and providing a reliable and professional tool for the Traditional Chinese Medical fields. However, all generated results are for reference only and should be provided by experienced professionals for diagnosis and treatment results and suggestions."
      ],
      "metadata": {
        "id": "NKOuDx1olwGw"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "print(torch.cuda.is_available())\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "JM-BsUrpeWJT",
        "outputId": "8d593699-2995-452d-c8be-936fafa0249e"
      },
      "execution_count": 6,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "True\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "!pip install transformers huggingface_hub accelerate peft"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "WFDppZ3udxdz",
        "outputId": "eac02119-90b6-478d-eb81-eba0f0232491"
      },
      "execution_count": 2,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.37.2)\n",
            "Requirement already satisfied: huggingface_hub in /usr/local/lib/python3.10/dist-packages (0.20.3)\n",
            "Collecting accelerate\n",
            "  Downloading accelerate-0.27.2-py3-none-any.whl (279 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m280.0/280.0 kB\u001b[0m \u001b[31m4.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hCollecting peft\n",
            "  Downloading peft-0.8.2-py3-none-any.whl (183 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m183.4/183.4 kB\u001b[0m \u001b[31m8.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.13.1)\n",
            "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.25.2)\n",
            "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (23.2)\n",
            "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0.1)\n",
            "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2023.12.25)\n",
            "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.31.0)\n",
            "Requirement already satisfied: tokenizers<0.19,>=0.14 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.15.2)\n",
            "Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.4.2)\n",
            "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.66.2)\n",
            "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface_hub) (2023.6.0)\n",
            "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface_hub) (4.9.0)\n",
            "Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate) (5.9.5)\n",
            "Requirement already satisfied: torch>=1.10.0 in /usr/local/lib/python3.10/dist-packages (from accelerate) (2.1.0+cu121)\n",
            "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (1.12)\n",
            "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (3.2.1)\n",
            "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (3.1.3)\n",
            "Requirement already satisfied: triton==2.1.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (2.1.0)\n",
            "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.3.2)\n",
            "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.6)\n",
            "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.0.7)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2024.2.2)\n",
            "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.10.0->accelerate) (2.1.5)\n",
            "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.10.0->accelerate) (1.3.0)\n",
            "Installing collected packages: accelerate, peft\n",
            "Successfully installed accelerate-0.27.2 peft-0.8.2\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# You should restart colab and the run the following code."
      ],
      "metadata": {
        "id": "PNSjkWCqmM8V"
      }
    },
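    {
      "cell_type": "markdown",
      "source": [
        "After the restart, the sanity check below can confirm that the freshly installed packages import cleanly and that the GPU is still visible. This is a minimal sketch; the exact versions printed will depend on the current Colab image."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "# Post-restart sanity check (a minimal sketch; printed versions will vary\n",
        "# with the Colab image).\n",
        "import torch\n",
        "import transformers\n",
        "import peft\n",
        "import accelerate\n",
        "\n",
        "print(\"torch:\", torch.__version__, \"| CUDA available:\", torch.cuda.is_available())\n",
        "print(\"transformers:\", transformers.__version__)\n",
        "print(\"peft:\", peft.__version__, \"| accelerate:\", accelerate.__version__)\n"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },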
    {
      "cell_type": "code",
      "source": [
        "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
        "import torch\n",
        "\n",
        "# Set the device\n",
        "device = \"cuda\"  # replace with your device: \"cpu\", \"cuda\", \"mps\"\n",
        "\n",
        "# Initialize model and tokenizer\n",
        "peft_model_id = \"CMLL/ZhongJing-2-1_8b\"\n",
        "base_model_id = \"Qwen/Qwen1.5-1.8B-Chat\"\n",
        "model = AutoModelForCausalLM.from_pretrained(base_model_id, device_map=\"auto\")\n",
        "model.load_adapter(peft_model_id)\n",
        "tokenizer = AutoTokenizer.from_pretrained(\n",
        "    \"CMLL/ZhongJing-2-1_8b\",\n",
        "    padding_side=\"right\",\n",
        "    trust_remote_code=True,\n",
        "    pad_token=''\n",
        ")\n",
        "\n",
        "def get_model_response(question, context):\n",
        "    # Create the prompt\n",
        "    prompt = f\"Question: {question}\\nContext: {context}\"\n",
        "    messages = [\n",
        "        {\"role\": \"system\", \"content\": \"You are a helpful TCM assistant named 仲景中医大语言模型.\"},\n",
        "        {\"role\": \"user\", \"content\": prompt}\n",
        "    ]\n",
        "\n",
        "    # Prepare the input\n",
        "    text = tokenizer.apply_chat_template(\n",
        "        messages,\n",
        "        tokenize=False,\n",
        "        add_generation_prompt=True\n",
        "    )\n",
        "    model_inputs = tokenizer([text], return_tensors=\"pt\").to(device)\n",
        "\n",
        "    # Generate the response\n",
        "    generated_ids = model.generate(\n",
        "        model_inputs.input_ids,\n",
        "        max_new_tokens=512\n",
        "    )\n",
        "    generated_ids = [\n",
        "        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)\n",
        "    ]\n",
        "\n",
        "    # Decode the response\n",
        "    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]\n",
        "    return response\n",
        "\n",
        "# Loop to get user input and provide model response\n",
        "while True:\n",
        "    user_question = input(\"Enter your question (or type 'exit' to stop): \")\n",
        "    if user_question.lower() == 'exit':\n",
        "        break\n",
        "    user_context = input(\"Enter context (or type 'none' if no context): \")\n",
        "    if user_context.lower() == 'none':\n",
        "        user_context = \"\"\n",
        "\n",
        "    print(\"Model is generating a response, please wait...\")\n",
        "    model_response = get_model_response(user_question, user_context)\n",
        "    print(\"Model's response:\", model_response)\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "id": "jsn4szdjdtmF",
        "outputId": "900e42e2-23be-4fb3-91d7-b586ba2d18a5"
      },
      "execution_count": 1,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:88: UserWarning: \n",
            "The secret `HF_TOKEN` does not exist in your Colab secrets.\n",
            "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n",
            "You will be able to reuse this secret in all of your notebooks.\n",
            "Please note that authentication is recommended but still optional to access public models or datasets.\n",
            "  warnings.warn(\n",
            "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
            "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Enter your question (or type 'exit' to stop): 你是谁\n",
            "Enter context (or type 'none' if no context): none\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
            "Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Model is generating a response, please wait...\n",
            "Model's response: 我是“仲景中医大语言模型”,也叫“仲景大医”或“大医大”。\n",
            "Enter your question (or type 'exit' to stop): 我发热,咳嗽,呼吸困难,给出中医诊断和处方\n",
            "Enter context (or type 'none' if no context): none\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
            "Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Model is generating a response, please wait...\n",
            "Model's response: “发高热,恶寒,咳嗽,喘促不能平卧者”属于“肺热壅盛”,患者应该根据自身病情采取中西医结合的方法来进行治疗。中医常采用清热解毒的药物,如白虎汤、黄连阿胶汤、银翘散等;西药常用布洛芬、对乙酰氨基酚等进行退烧。当症状缓解时,可以使用疏风清热、宣肺降气的中药方剂,如麻杏石甘汤或加减葳蕤汤来治疗。\n",
            "Enter your question (or type 'exit' to stop): 我还咳嗽,痰黄,该怎么办\n",
            "Enter context (or type 'none' if no context): none\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
            "Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Model is generating a response, please wait...\n",
            "Model's response: 患者出现咳喘症状时,应及时就医。根据不同的病情和病因采取相应措施,如止咳、化痰、解表等。\n",
            "1、如果患者有发热等全身不适的表现,则可以使用消炎药、抗病毒药物进行治疗;\n",
            "2、如果患者是因感冒、肺炎导致的咳喘,则可以服用一些清热、解毒、止咳、化痰的中药,或者给予雾化吸入治疗;\n",
            "3、如果患者有咳嗽的症状,并伴有痰黄,可以在医生的指导下应用甘草合剂、川贝枇杷露、鲜竹沥液等药物进行治疗;\n",
            "4、如果是外感风寒所致的咳喘,则可以使用生姜、葱白、红糖煎水进行治疗,同时还可以应用疏风散寒、温肺化饮的中药。\n",
            "总之,在发生咳嗽的同时,还需结合具体的病因、临床表现和伴随的其他相关症状来综合分析,才能准确评估患者的病情并及时进行相应的治疗。\n",
            "Enter your question (or type 'exit' to stop): 谢谢\n",
            "Enter context (or type 'none' if no context): none\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
            "Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Model is generating a response, please wait...\n",
            "Model's response: 您好,非常感谢您的提问。\n"
          ]
        },
        {
          "output_type": "error",
          "ename": "KeyboardInterrupt",
          "evalue": "Interrupted by user",
          "traceback": [
            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
            "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
            "\u001b[0;32m<ipython-input-1-bd9c00e18996>\u001b[0m in \u001b[0;36m<cell line: 49>\u001b[0;34m()\u001b[0m\n\u001b[1;32m     48\u001b[0m \u001b[0;31m# Loop to get user input and provide model response\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     49\u001b[0m \u001b[0;32mwhile\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 50\u001b[0;31m     \u001b[0muser_question\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Enter your question (or type 'exit' to stop): \"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     51\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0muser_question\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlower\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m'exit'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     52\u001b[0m         \u001b[0;32mbreak\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py\u001b[0m in \u001b[0;36mraw_input\u001b[0;34m(self, prompt)\u001b[0m\n\u001b[1;32m    849\u001b[0m                 \u001b[0;34m\"raw_input was called, but this frontend does not support input requests.\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    850\u001b[0m             )\n\u001b[0;32m--> 851\u001b[0;31m         return self._input_request(str(prompt),\n\u001b[0m\u001b[1;32m    852\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_parent_ident\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    853\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_parent_header\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py\u001b[0m in \u001b[0;36m_input_request\u001b[0;34m(self, prompt, ident, parent, password)\u001b[0m\n\u001b[1;32m    893\u001b[0m             \u001b[0;32mexcept\u001b[0m \u001b[0mKeyboardInterrupt\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    894\u001b[0m                 \u001b[0;31m# re-raise KeyboardInterrupt, to truncate traceback\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 895\u001b[0;31m                 \u001b[0;32mraise\u001b[0m \u001b[0mKeyboardInterrupt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Interrupted by user\"\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    896\u001b[0m             \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    897\u001b[0m                 \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlog\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwarning\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Invalid Message:\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mexc_info\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;31mKeyboardInterrupt\u001b[0m: Interrupted by user"
          ]
        }
      ]
    },
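    {
      "cell_type": "markdown",
      "source": [
        "The loop above blocks on `input()`, which is awkward for scripted use. The sketch below reuses the `get_model_response` helper defined above for a single, non-interactive call; the question text is only an illustrative example."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "# One-shot, non-interactive call reusing get_model_response (a minimal\n",
        "# sketch; the question below is only an illustrative example).\n",
        "answer = get_model_response(\n",
        "    \"What TCM pattern is suggested by cough with yellow phlegm?\",\n",
        "    \"\"\n",
        ")\n",
        "print(answer)\n"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },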
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "Q8pwg9UlitWI"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}